Last Updated on 2024-11-01 by Clay
During decoding, auto-regressive large language models must generate output step by step, one token at a time, until the entire sequence is complete. Within this process, caching techniques can help reduce computation and improve decoding speed; one such technique is known as the KV Cache.
This technique is extremely important, and I have always wanted to write a detailed note about it. Perhaps in the future, when the concept starts to feel unfamiliar again, I can revisit this note and quickly refresh what I learned the first time.
Quick Review of the Self-Attention Mechanism
The self-attention mechanism allows tokens to integrate information from different positions across the sequence. After the input sequence passes through the embedding layer, which converts token IDs into vectors, it has the shape (batch_size, seq_len, hidden_size). It is then split into multiple heads (multi-head), changing the shape to (batch_size, head_num, seq_len, head_size), where head_size is simply hidden_size / head_num.
The main focus is on seq_len and head_size, often abbreviated as M x N in many notes. Through linear transformations we derive the Q, K, and V matrices, each with shape (seq_len, head_size), and then compute attention using the following formula:
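Attention(Q, K, V) = softmax(QKᵀ / √head_size) V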
This computation is repeated multiple times within the attention layers of the Transformer architecture.
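To make these shapes concrete, here is a minimal sketch of one multi-head attention pass. The sizes are made-up illustrative values (roughly GPT-2 small), and the same random tensor stands in for Q, K, and V just to keep the example short:
import torch
import torch.nn.functional as F

# Illustrative sizes only
batch_size, seq_len, hidden_size, head_num = 2, 10, 768, 12
head_size = hidden_size // head_num

x = torch.randn(batch_size, seq_len, hidden_size)

# Split the hidden dimension into heads: (batch_size, head_num, seq_len, head_size)
x = x.view(batch_size, seq_len, head_num, head_size).permute(0, 2, 1, 3)

# Pretend the Q/K/V linear projections have already been applied
q = k = v = x

# Scaled dot-product attention
scores = torch.matmul(q, k.transpose(-1, -2)) / (head_size ** 0.5)
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, v)
print(output.shape)  # torch.Size([2, 12, 10, 64])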
What is KV Cache?
In a nutshell, KV Cache is a type of cache that helps reduce redundant computation. Refer to the image below:
Do you see the Key Cache and Value Cache in the image? The keys and values of the previous seq_len - 1 tokens, obtained from the linear transformations, would otherwise be recomputed redundantly: tokens earlier in the sequence do not change as the sequence grows during auto-regressive generation of the next token. By caching their keys and values, we reduce both the linear-transformation work and the intermediate hidden states of the QKV computation.
Is the saving only the Key and Value computation for the current sequence length minus one? Not quite, because these attention layers are stacked many times in the model. Assuming there are 12 layers, we save that computation in every one of the 12 layers.
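For a rough sense of scale, here is a quick back-of-the-envelope count for a GPT-2-small-sized configuration (12 layers, hidden size 768; the sequence length here is just an example):
# Rough count of cached activations for a GPT-2-small-like model
n_layer, hidden_size, seq_len = 12, 768, 1024
cached_values = 2 * n_layer * seq_len * hidden_size  # 2 -> one Key and one Value per position
print(cached_values)  # 18874368 values that never have to be recomputed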
For further insights, you can refer to my past GPT-2 implementation based on the HuggingFace Transformers architecture (https://github.com/ccs96307/gpt2-pytorch-implemented/tree/main).
In this implementation, I defined the GPT2Attention and GPT2MLP modules, which are wrapped by the GPT2Block module and repeated for config.n_layer layers. If use_cache is enabled, each GPT2Attention module returns a present cache containing (k, v), which is collected at the outermost level and reused during decoding.
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

# Conv1D, NewGELUActivation, and GPT2Config are defined in the linked repository;
# BaseModelOutputWithPastAndCrossAttentions can be imported from transformers.modeling_outputs.


class GPT2Attention(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
super().__init__()
# Init
self.n_head = config.n_head
self.head_size = int(config.hidden_size / self.n_head)
self.scale = 1 / (self.head_size ** 0.5)
self.hidden_size = config.hidden_size
self.c_attn = Conv1D(input_dim=config.hidden_size, output_dim=3*config.hidden_size)
self.c_proj = Conv1D(input_dim=config.hidden_size, output_dim=config.hidden_size)
self.attn_dropout = torch.nn.Dropout(p=config.attn_pdrop)
self.resid_dropout = torch.nn.Dropout(p=config.resid_pdrop)
def forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
    ) -> Tuple:
batch_size = hidden_states.size(0)
# QKV
qkv = self.c_attn(hidden_states)
q, k, v = qkv.split(self.hidden_size, dim=-1)
# Reshape
q = q.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
k = k.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
v = v.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
if use_cache:
present = (k, v)
else:
present = None
# Compute Q @ K^T
attention_scores = torch.matmul(q, k.transpose(-1, -2))
attention_scores = attention_scores * self.scale
        # Causal mask (sized to the full key length so it also works when cached keys are prepended)
        query_len = q.size(-2)
        key_len = k.size(-2)
        mask_value = torch.finfo(hidden_states.dtype).min
        causal_mask = torch.triu(
            torch.full((query_len, key_len), mask_value, dtype=hidden_states.dtype, device=hidden_states.device),
            diagonal=key_len - query_len + 1,
        )
        attention_scores = attention_scores + causal_mask
# Attention mask
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=hidden_states.dtype)
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
attention_scores = attention_scores + attention_mask
# Softmax
attention_weights = F.softmax(attention_scores, dim=-1)
attention_weights = self.attn_dropout(attention_weights)
        # Weighted sum over the values
        context_layer = torch.matmul(attention_weights, v)
        # Reshape back to (batch_size, seq_len, hidden_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(batch_size, -1, self.head_size * self.n_head)
        attention_output = self.c_proj(context_layer)
        # Dropout (the residual connection is applied in GPT2Block)
        attention_output = self.resid_dropout(attention_output)
outputs = (attention_output, present)
if output_attentions:
outputs += (attention_weights,)
return outputs
class GPT2MLP(torch.nn.Module):
def __init__(self, inner_dim: int, config: GPT2Config) -> None:
super().__init__()
self.c_fc = Conv1D(input_dim=config.hidden_size, output_dim=inner_dim)
self.c_proj = Conv1D(input_dim=inner_dim, output_dim=config.hidden_size)
self.act = NewGELUActivation()
self.dropout = torch.nn.Dropout(p=config.resid_pdrop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.c_fc(x)
x = self.act(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class GPT2Block(torch.nn.Module):
def __init__(self, config: GPT2Config) -> None:
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
self.ln_1 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config=config)
self.ln_2 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config=config)
def forward(
self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
# Self-Attention
hidden_states = self.ln_1(hidden_states)
attention_outputs = self.attn(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attention_outputs[0] # output_attn: (attention_output, present, all_attentions)
outputs = attention_outputs[1:]
# Residual connection
hidden_states = attention_output + residual
residual = hidden_states
# MLP
hidden_states = self.ln_2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
# Cache
if use_cache:
outputs = (hidden_states,) + outputs # outputs: (hidden_states, present, all_attentions)
else:
            outputs = (hidden_states,) + outputs[1:]  # outputs: (hidden_states, all_attentions)
return outputs
class GPT2Model(torch.nn.Module):
def __init__(self, config: GPT2Config) -> None:
super().__init__()
self.wte = torch.nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = torch.nn.Embedding(config.n_positions, config.n_embd)
self.dropout = torch.nn.Dropout(p=config.embd_pdrop)
self.h = torch.nn.ModuleList([GPT2Block(config=config) for _ in range(config.n_layer)])
self.ln_f = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPastAndCrossAttentions:
# Token embeddings
token_embeddings = self.wte(input_ids)
        # Position embeddings (offset by the cached length so decoding steps get the correct positions)
        if position_ids is None:
            past_length = 0
            if past_key_values is not None and past_key_values[0] is not None:
                past_length = past_key_values[0][0].size(-2)
            position_ids = torch.arange(past_length, past_length + input_ids.size(1), device=input_ids.device)
            position_embeddings = self.wpe(position_ids).view(1, -1, token_embeddings.size(-1))
        else:
            position_embeddings = self.wpe(position_ids)
# Sum the embeddings
embeddings = token_embeddings + position_embeddings
# KV Cache
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Computation
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
hidden_states = self.dropout(embeddings)
for block, layer_past in zip(self.h, past_key_values):
if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache is True else 1],)
# LayerNorm
hidden_states = self.ln_f(hidden_states)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=None,
)
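To make the flow concrete, here is a minimal sketch of how the cache would be used with the GPT2Model above: run the prompt once to fill the cache, then feed only the newest token on each later step. The token IDs are arbitrary placeholders, and I assume GPT2Config provides GPT-2-style default values:
config = GPT2Config()
model = GPT2Model(config=config)
model.eval()

prompt_ids = torch.tensor([[464, 2068, 7586, 21831]])  # arbitrary example token IDs

with torch.no_grad():
    # Prefill: process the whole prompt once and keep every layer's (k, v)
    outputs = model(input_ids=prompt_ids, use_cache=True)
    past_key_values = outputs.past_key_values

    # Decode step: only the newest token is passed in; all earlier keys and values come from the cache
    next_token_id = torch.tensor([[318]])  # placeholder for whatever token was just sampled
    outputs = model(input_ids=next_token_id, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values

# The cached keys of layer 0 now cover all 5 positions: (batch_size, n_head, 5, head_size)
print(past_key_values[0][0].shape)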
The savings from the KV Cache can be understood in simple terms: at each decoding step, if the sequence already contains 10 tokens, we skip recomputing the keys and values for 9 of them; if the sequence length is 1,000,000,000, we skip 999,999,999 of them!
Can you now see the significance of the KV Cache?
References
- Transformers KV Caching Explained | by João Lages
- Unlocking Longer Generation with Key-Value Cache ...