
KV Cache: A Caching Mechanism To Accelerate Transformer Generation

Last Updated on 2024-11-01 by Clay

Large language models are auto-regressive: during decoding they generate tokens step by step until the entire sequence is complete. Within this process, caching techniques can help reduce computation and improve decoding speed; one such technique is known as the KV Cache.

This technique is extremely important, and I have always wanted to write a detailed note about it. Perhaps in the future, when this concept feels unfamiliar to me, I can revisit it and quickly refresh the memories from my initial learning experience.


Quick Review of the Self-Attention Mechanism

The self-attention mechanism allows tokens to integrate information from various positions across a sequence. Once the embedding layer converts the input token IDs to vectors, the sequence has the shape (batch_size, seq_len, hidden_size). It is then split into multiple heads (multi-head attention), changing the shape to (batch_size, head_num, seq_len, head_size), where head_size = hidden_size / head_num.

The dimensions that matter here are seq_len and head_size, often abbreviated as M x N in other notes. Through linear transformations we derive the Q, K, and V matrices, each of shape (seq_len, head_size) per head, and then compute:

\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V

This computation is repeated multiple times within the attention layers of the Transformer architecture.
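As a quick sanity check of these shapes, here is a minimal, self-contained sketch of the computation above (toy dimensions and random tensors, purely illustrative; it is not code from the implementation discussed below):

import torch
import torch.nn.functional as F

# Toy dimensions (illustrative only)
batch_size, head_num, seq_len, head_size = 1, 12, 8, 64

# Q, K, V as produced by the linear projections and the multi-head reshape
q = torch.randn(batch_size, head_num, seq_len, head_size)
k = torch.randn(batch_size, head_num, seq_len, head_size)
v = torch.randn(batch_size, head_num, seq_len, head_size)

# Softmax(Q K^T / sqrt(d)) V
scores = torch.matmul(q, k.transpose(-1, -2)) / (head_size ** 0.5)
weights = F.softmax(scores, dim=-1)
context = torch.matmul(weights, v)

print(context.shape)  # torch.Size([1, 12, 8, 64])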


What is KV Cache?

In a nutshell, KV Cache is a type of cache that helps reduce redundant computation. Refer to the image below:

Do you see the Key Cache and Value Cache in the image? Without a cache, the keys and values of the first seq_len - 1 tokens would be recomputed from their linear transformations at every step, which is redundant: previous tokens do not change as the sequence grows during auto-regressive generation of the next token. By caching their keys and values, we avoid recomputing those linear transformations and the intermediate hidden states in the QKV computation.

Is the saving only the key and value computations for the current sequence length - 1? Not quite: these attention layers are stacked many times within the model. If there are 12 layers, we save the computation for sequence length - 1 tokens in every one of those 12 layers.
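To make the redundancy concrete, here is a small, self-contained sketch (toy projections with made-up sizes, not the repository code) comparing the no-cache path, which re-projects the whole prefix at every step, with the cached path, which projects only the newest token and appends it to the cache:

import torch

hidden_size = 64
w_k = torch.nn.Linear(hidden_size, hidden_size, bias=False)
w_v = torch.nn.Linear(hidden_size, hidden_size, bias=False)

# Pretend these are the hidden states produced one token at a time during generation
steps = [torch.randn(1, 1, hidden_size) for _ in range(5)]

# Without a cache: the final step alone re-projects all 5 tokens
prefix = torch.cat(steps, dim=1)          # (1, 5, hidden_size)
k_full = w_k(prefix)

# With a KV Cache: each token is projected exactly once, then appended to the cache
k_cache, v_cache = None, None
for x in steps:
    k_new, v_new = w_k(x), w_v(x)         # one projection per step
    k_cache = k_new if k_cache is None else torch.cat((k_cache, k_new), dim=1)
    v_cache = v_new if v_cache is None else torch.cat((v_cache, v_new), dim=1)

assert torch.allclose(k_full, k_cache)    # same keys, far fewer projections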

For further insights, you can refer to my past GPT-2 implementation based on the HuggingFace Transformers architecture (https://github.com/ccs96307/gpt2-pytorch-implemented/tree/main).

In this implementation, I defined the GPT2Attention and GPT2MLP modules, which are then encapsulated by the GPT2Block module and repeated for config.n_layer layers. If use_cache is enabled, each GPT2Attention module will return a present cache containing (k, v), which can be collected at the outermost level and reused during decoding.

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

# Conv1D, GPT2Config, NewGELUActivation, and BaseModelOutputWithPastAndCrossAttentions
# are defined (or imported) elsewhere in the linked repository.


class GPT2Attention(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        # Init
        self.n_head = config.n_head
        self.head_size = int(config.hidden_size / self.n_head)
        self.scale = 1 / (self.head_size ** 0.5)
        self.hidden_size = config.hidden_size

        self.c_attn = Conv1D(input_dim=config.hidden_size, output_dim=3*config.hidden_size)
        self.c_proj = Conv1D(input_dim=config.hidden_size, output_dim=config.hidden_size)
        self.attn_dropout = torch.nn.Dropout(p=config.attn_pdrop)
        self.resid_dropout = torch.nn.Dropout(p=config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple:
        batch_size = hidden_states.size(0)

        # QKV
        qkv = self.c_attn(hidden_states)
        q, k, v = qkv.split(self.hidden_size, dim=-1)

        # Reshape
        q = q.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
        k = k.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
        v = v.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        
        if use_cache:
            present = (k, v)
        else:
            present = None

        # Compute Q @ K^T
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores * self.scale

        # Causal mask (the query length can be shorter than the key length when a cache is used)
        q_len = q.size(-2)
        kv_len = k.size(-2)
        mask_value = torch.finfo(attention_scores.dtype).min
        causal_mask = torch.triu(
            torch.full((q_len, kv_len), mask_value, dtype=attention_scores.dtype, device=attention_scores.device),
            diagonal=kv_len - q_len + 1,
        )
        attention_scores = attention_scores + causal_mask

        # Attention mask
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
            attention_scores = attention_scores + attention_mask

        # Softmax
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.attn_dropout(attention_weights)

        # Weighted sum over the value vectors
        context_layer = torch.matmul(attention_weights, v)

        # Reshape back to (batch_size, seq_len, hidden_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(batch_size, -1, self.head_size * self.n_head)
        attention_output = self.c_proj(context_layer)

        # Residual dropout (the skip connection itself is added in GPT2Block)
        attention_output = self.resid_dropout(attention_output)

        outputs = (attention_output, present)
        if output_attentions:
            outputs += (attention_weights,)
        
        return outputs


class GPT2MLP(torch.nn.Module):
    def __init__(self, inner_dim: int, config: GPT2Config) -> None:
        super().__init__()
        self.c_fc = Conv1D(input_dim=config.hidden_size, output_dim=inner_dim)
        self.c_proj = Conv1D(input_dim=inner_dim, output_dim=config.hidden_size)
        self.act = NewGELUActivation()
        self.dropout = torch.nn.Dropout(p=config.resid_pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x        


class GPT2Block(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size

        self.ln_1 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config=config)
        self.ln_2 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states

        # Self-Attention
        hidden_states = self.ln_1(hidden_states)
        attention_outputs = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]  # output_attn: (attention_output, present, all_attentions)
        outputs = attention_outputs[1:]
        
        # Residual connection
        hidden_states = attention_output + residual
        residual = hidden_states

        # MLP
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual

        # Cache
        if use_cache:
            outputs = (hidden_states,) + outputs  # outputs: (hidden_states, present, [attention_weights])
        else:
            outputs = (hidden_states,) + outputs[1:]  # drop the None present: (hidden_states, [attention_weights])

        return outputs
    

class GPT2Model(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        self.wte = torch.nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = torch.nn.Embedding(config.n_positions, config.n_embd)
        self.dropout = torch.nn.Dropout(p=config.embd_pdrop)
        self.h = torch.nn.ModuleList([GPT2Block(config=config) for _ in range(config.n_layer)])
        self.ln_f = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        # KV Cache (determine how many positions are already cached)
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        # Token embeddings
        token_embeddings = self.wte(input_ids)

        # Position embeddings (offset by the cached length so new tokens get correct positions)
        if position_ids is None:
            position_ids = torch.arange(past_length, past_length + input_ids.size(1), device=input_ids.device)
            position_embeddings = self.wpe(position_ids).view(1, -1, token_embeddings.size(-1))
        else:
            position_embeddings = self.wpe(position_ids)

        # Sum the embeddings
        embeddings = token_embeddings + position_embeddings

        # Computation
        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        
        hidden_states = self.dropout(embeddings)

        for block, layer_past in zip(self.h, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]

            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache is True else 1],)
        
        # LayerNorm
        hidden_states = self.ln_f(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=None,
        )
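
Continuing directly from the listing above, here is a hedged sketch of how the collected cache might be reused during decoding. It assumes GPT2Config can be instantiated with its defaults and uses random token IDs; a real decoding loop would also apply a language-model head on top of the hidden states to choose the next token.

config = GPT2Config()
model = GPT2Model(config)
model.eval()

prompt_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    # Prefill: run the whole prompt once and collect the per-layer (k, v) caches
    out = model(input_ids=prompt_ids, use_cache=True)
    past = out.past_key_values

    # Decode step: feed only the newest token; the cached keys/values cover the prefix
    next_token = torch.randint(0, config.vocab_size, (1, 1))
    position_ids = torch.tensor([[prompt_ids.size(1)]])
    out = model(
        input_ids=next_token,
        past_key_values=past,
        position_ids=position_ids,
        use_cache=True,
    )

print(out.past_key_values[0][0].shape)  # the layer-0 keys now cover 9 positions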


The savings from the KV Cache can be thought of in simple terms: at each decoding step, if the sequence already holds 10 tokens, we skip recomputing the keys and values for 9 of them; if it holds 1,000,000,000 tokens, we skip 999,999,999 of them, in every attention layer.
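Under simple assumptions (counting only key/value projections per layer and ignoring everything else), the cost grows quadratically without a cache but only linearly with one; the hypothetical helper below just tallies the difference:

# Rough tally of per-layer K/V projections when generating n_new tokens from scratch
# (illustrative arithmetic only, not a benchmark).
def kv_projections(n_new: int, use_cache: bool) -> int:
    if use_cache:
        return n_new                       # each token is projected exactly once
    return sum(range(1, n_new + 1))        # step t re-projects all t tokens so far

print(kv_projections(10, use_cache=False))  # 55
print(kv_projections(10, use_cache=True))   # 10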

Can you now see the significance of the KV Cache?

