
KV Cache: A Caching Mechanism That Speeds Up Transformer Model Generation

Last Updated on 2024-10-30 by Clay

During decoding with large language models, especially auto-regressive models, we inevitably have to decode step by step until the whole sequence is generated. Along the way there is a caching trick that helps the model cut down on computation and speed up decoding; this trick is called the KV Cache.

This technique is important enough that I have long wanted to write a careful note about it, so that someday, when the concept has grown unfamiliar to me again, I can pull this note out, read through it, and hopefully quickly recover what I learned back then.


A Quick Review of the Self-Attention Mechanism

Self-attention is an information-mixing mechanism that lets every token attend to every other token in the sequence. After the embedding layer converts the input token IDs into vectors, the input sequence has shape (batch_size, seq_len, hidden_size). It is then split into multiple heads, becoming (batch_size, head_num, seq_len, head_size), where head_size is simply hidden_size / head_num.

What we really need to focus on are the last two dimensions, seq_len and head_size, which many notes simply write as M x N. Through the three weight matrices W_Q, W_K, and W_V, we linearly project the input into Q, K, and V, each of shape (seq_len, head_size), and then compute:

\text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V

and we obtain the final result. This computation is repeated over and over in every attention layer of the Transformer architecture.
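
To make the shapes concrete, here is a minimal PyTorch sketch of the computation above. The tensor names and sizes are my own illustration and are not taken from any particular library:

import torch
import torch.nn.functional as F

batch_size, seq_len, hidden_size, head_num = 2, 10, 768, 12
head_size = hidden_size // head_num   # head_size = hidden_size / head_num

hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# W_Q, W_K, W_V as linear projections
W_Q = torch.nn.Linear(hidden_size, hidden_size, bias=False)
W_K = torch.nn.Linear(hidden_size, hidden_size, bias=False)
W_V = torch.nn.Linear(hidden_size, hidden_size, bias=False)

# (batch_size, seq_len, hidden_size) -> (batch_size, head_num, seq_len, head_size)
def split_heads(x: torch.Tensor) -> torch.Tensor:
    return x.view(batch_size, seq_len, head_num, head_size).permute(0, 2, 1, 3)

q = split_heads(W_Q(hidden_states))
k = split_heads(W_K(hidden_states))
v = split_heads(W_V(hidden_states))

# Softmax(QK^T / sqrt(d)) V
scores = torch.matmul(q, k.transpose(-1, -2)) / (head_size ** 0.5)   # (batch, head, seq_len, seq_len)
output = torch.matmul(F.softmax(scores, dim=-1), v)                  # (batch, head, seq_len, head_size)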


What Is KV Cache?

In a nutshell, a KV Cache is exactly what its name says: a cache that helps us avoid some repeated computation. You can refer to the illustration below:

Do you see the Key Cache and Value Cache in the figure? In fact, of the Keys and Values produced by the linear projections, seq_len - 1 of them are redundant work: the tokens already in the sequence do not change state as the sequence grows (while we keep decoding the next token auto-regressively). So by caching their Keys and Values, we save the cost of those linear projections, along with some of the intermediate hidden states in the QKV computation.

You might think we only save "current sequence length - 1" worth of Key/Value computation, but it is more than that: the model stacks many attention layers. With 12 layers, for example, we save that "current sequence length - 1" amount of work 12 times over.
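
Here is a minimal sketch of that idea for a single head (my own illustration, not the repository code; W_K, W_V, and the sizes are assumptions): at every decoding step, only the newest token goes through the Key/Value projections, and the result is concatenated onto the cache.

import torch

hidden_size, head_size = 768, 64
W_K = torch.nn.Linear(hidden_size, head_size, bias=False)
W_V = torch.nn.Linear(hidden_size, head_size, bias=False)

k_cache, v_cache = None, None

for _ in range(5):                                   # 5 decoding steps
    new_hidden = torch.randn(1, 1, hidden_size)      # only the newest token's hidden state
    new_k = W_K(new_hidden)                          # (1, 1, head_size)
    new_v = W_V(new_hidden)

    if k_cache is not None:
        new_k = torch.cat((k_cache, new_k), dim=1)   # reuse the cached Keys
        new_v = torch.cat((v_cache, new_v), dim=1)   # reuse the cached Values

    k_cache, v_cache = new_k, new_v

print(k_cache.shape)   # torch.Size([1, 5, 64]): each token's Key was projected exactly once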

For this part, you can also refer to my earlier GPT-2 implementation modeled on the HuggingFace Transformers architecture (https://github.com/ccs96307/gpt2-pytorch-implemented/tree/main).

You can see that I define GPT2Attention and GPT2MLP modules, which are then wrapped inside GPT2Block and repeated for config.n_layer layers. As long as use_cache is enabled, each GPT2Attention returns, at the very end, a present cache that wraps (k, v); we simply collect these at the outermost level and pass them back in at every decoding step.

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

# Conv1D, NewGELUActivation, GPT2Config, and BaseModelOutputWithPastAndCrossAttentions
# are defined elsewhere in the repository linked above.


class GPT2Attention(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        # Init
        self.n_head = config.n_head
        self.head_size = int(config.hidden_size / self.n_head)
        self.scale = 1 / (self.head_size ** 0.5)
        self.hidden_size = config.hidden_size

        self.c_attn = Conv1D(input_dim=config.hidden_size, output_dim=3*config.hidden_size)
        self.c_proj = Conv1D(input_dim=config.hidden_size, output_dim=config.hidden_size)
        self.attn_dropout = torch.nn.Dropout(p=config.attn_pdrop)
        self.resid_dropout = torch.nn.Dropout(p=config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        batch_size = hidden_states.size(0)

        # QKV
        qkv = self.c_attn(hidden_states)
        q, k, v = qkv.split(self.hidden_size, dim=-1)

        # Reshape
        q = q.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
        k = k.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
        v = v.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        
        if use_cache:
            present = (k, v)
        else:
            present = None

        # Compute Q @ K^T
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores * self.scale

        # Causal mask (when a KV cache is used, the key length exceeds the query length,
        # so the mask is built as (query_len, key_len) with the proper diagonal offset)
        query_len = q.size(-2)
        key_len = k.size(-2)
        mask_value = torch.finfo(hidden_states.dtype).min
        causal_mask = torch.triu(
            torch.full((query_len, key_len), mask_value, dtype=hidden_states.dtype, device=hidden_states.device),
            diagonal=key_len - query_len + 1,
        )
        attention_scores = attention_scores + causal_mask

        # Attention mask
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
            attention_scores = attention_scores + attention_mask

        # Softmax
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.attn_dropout(attention_weights)

        # Weighted sum over V
        attention_scores = torch.matmul(attention_weights, v)

        # Reshape
        context_layer = attention_scores.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(batch_size, -1, self.head_size * self.n_head)
        attention_output = self.c_proj(context_layer)

        # Dropout (the residual connection is applied in GPT2Block)
        attention_output = self.resid_dropout(attention_output)

        outputs = (attention_output, present)
        if output_attentions:
            outputs += (attention_weights,)
        
        return outputs


class GPT2MLP(torch.nn.Module):
    def __init__(self, inner_dim: int, config: GPT2Config) -> None:
        super().__init__()
        self.c_fc = Conv1D(input_dim=config.hidden_size, output_dim=inner_dim)
        self.c_proj = Conv1D(input_dim=inner_dim, output_dim=config.hidden_size)
        self.act = NewGELUActivation()
        self.dropout = torch.nn.Dropout(p=config.resid_pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x        


class GPT2Block(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size

        self.ln_1 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config=config)
        self.ln_2 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states

        # Self-Attention
        hidden_states = self.ln_1(hidden_states)
        attention_outputs = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]  # output_attn: (attention_output, present, all_attentions)
        outputs = attention_outputs[1:]
        
        # Residual connection
        hidden_states = attention_output + residual
        residual = hidden_states

        # MLP
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual

        # Cache
        if use_cache:
            outputs = (hidden_states,) + outputs  # outputs: (hidden_states, present, all_attentions)
        else:
            outputs = (hidden_states,) + outputs[1:]  # outputs: (hidden_states, all_attentions)

        return outputs
    

class GPT2Model(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        self.wte = torch.nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = torch.nn.Embedding(config.n_positions, config.n_embd)
        self.dropout = torch.nn.Dropout(p=config.embd_pdrop)
        self.h = torch.nn.ModuleList([GPT2Block(config=config) for _ in range(config.n_layer)])
        self.ln_f = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        # Token embeddings
        token_embeddings = self.wte(input_ids)

        # Position embeddings (when a KV cache is passed in, the new tokens start
        # at position `past_length`, not at 0)
        if position_ids is None:
            past_length = 0
            if past_key_values is not None and past_key_values[0] is not None:
                past_length = past_key_values[0][0].size(-2)

            position_ids = torch.arange(past_length, past_length + input_ids.size(1), device=input_ids.device)
            position_embeddings = self.wpe(position_ids).view(1, -1, token_embeddings.size(-1))
        else:
            position_embeddings = self.wpe(position_ids)

        # Sum the embeddings
        embeddings = token_embeddings + position_embeddings

        # KV Cache
        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Computation
        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        
        hidden_states = self.dropout(embeddings)

        for block, layer_past in zip(self.h, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]

            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache is True else 1],)
        
        # LayerNorm
        hidden_states = self.ln_f(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=None,
        )
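
To show how the cache actually flows at inference time, here is a hedged usage sketch (the greedy-decoding loop, the prompt IDs, the lm_head, and the assumption that GPT2Config provides the usual GPT-2 defaults are my own illustration, not part of the excerpt above): we feed the whole prompt once, then at every later step we feed only the newest token together with the past_key_values returned from the previous step.

# Hypothetical greedy-decoding loop (illustration only)
config = GPT2Config()
model = GPT2Model(config=config)
model.eval()                                            # disable dropout for inference
lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

input_ids = torch.tensor([[464, 3290, 318]])            # an arbitrary prompt
past_key_values = None

for _ in range(20):
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )

    # One (k, v) pair per layer, collected at the outermost level
    past_key_values = outputs.past_key_values

    # Greedily pick the next token from the last position's hidden state
    next_token = lm_head(outputs.last_hidden_state[:, -1]).argmax(dim=-1, keepdim=True)

    # From the second step on, only the newest token is fed in; its Key/Value
    # get concatenated onto the cache inside every GPT2Attention layer.
    input_ids = next_token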


Looked at from another angle, the computation the KV Cache saves is very intuitive: if the current sequence length is only 10 tokens, then at this decoding step we skip re-projecting 9 of them; if the sequence length were 1,000,000,000, we would skip 999,999,999 of them!
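
As a rough back-of-envelope count of my own (it only counts the Key/Value projections and ignores the attention matmuls themselves): without a cache, generating N tokens re-projects 1 + 2 + \dots + N = \frac{N(N+1)}{2} tokens per layer, while with a cache every token is projected exactly once, i.e. N times per layer. For N = 10 that is 55 versus 10 projections, and the gap grows quadratically with sequence length.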

Can you feel how important the KV Cache is now?

