Last Updated on 2024-10-30 by Clay
During decoding with large language models, especially auto-regressive models, we inevitably have to decode step after step until the entire sequence has been generated. Along the way there is a caching trick that helps the model cut down on computation and speed up decoding; this trick is called the KV Cache.
This technique matters enough that I have long wanted to write a careful note on it, so that someday, when the concept has grown unfamiliar to me again, I can pull this post out, read through it, and perhaps quickly recover what I learned the first time around.
A Quick Review of the Self-Attention Mechanism
Self-attention is an information-aggregation mechanism that lets each token attend to every other token in the sequence. After the input sequence passes through the embedding layer, which converts token IDs into vectors, it has the shape (batch_size, seq_len, hidden_size). It is then split into multiple heads, becoming (batch_size, head_num, seq_len, head_size), where head_size is simply hidden_size / head_num.
What we really need to focus on are the last two dimensions, seq_len and head_size, which many notes simply write as M x N. We linearly transform the hidden states with the three weight matrices W_Q, W_K, and W_V into Q, K, and V, each of shape (seq_len, head_size), and then compute attention with the following formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where d_k is head_size. This gives us the final output of the attention head, and the same computation is repeated in every attention layer of the Transformer architecture.
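To make the shapes concrete, here is a minimal sketch of that computation in plain PyTorch. Everything in it is illustrative (random tensors, GPT-2-small-like sizes, hypothetical W_Q/W_K/W_V layers), not code from any particular model:

import torch
import torch.nn.functional as F

batch_size, seq_len, hidden_size, head_num = 1, 10, 768, 12
head_size = hidden_size // head_num  # 64

hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Hypothetical projection layers standing in for W_Q, W_K, W_V
W_Q = torch.nn.Linear(hidden_size, hidden_size)
W_K = torch.nn.Linear(hidden_size, hidden_size)
W_V = torch.nn.Linear(hidden_size, hidden_size)

def split_heads(x: torch.Tensor) -> torch.Tensor:
    # (batch_size, seq_len, hidden_size) -> (batch_size, head_num, seq_len, head_size)
    return x.view(batch_size, -1, head_num, head_size).permute(0, 2, 1, 3)

q = split_heads(W_Q(hidden_states))
k = split_heads(W_K(hidden_states))
v = split_heads(W_V(hidden_states))

# softmax(Q K^T / sqrt(d_k)) V, computed independently for every head
scores = torch.matmul(q, k.transpose(-1, -2)) / (head_size ** 0.5)
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, v)  # (batch_size, head_num, seq_len, head_size)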
What Is the KV Cache?
In a nutshell, the KV Cache is a cache that helps us avoid some redundant computation; the diagram below illustrates the idea:
Do you see the Key Cache and Value Cache in the figure? In fact, recomputing the Keys and Values of the previous seq_len - 1 tokens through the linear transformations is wasted work: the tokens already in the sequence do not change their states as the sequence grows (while we keep decoding the next token auto-regressively). By caching their Keys and Values, we cut the cost of those linear transformations as well as some of the intermediate hidden states in the QKV computation.
And you might think we only save the "current sequence length - 1" worth of Key/Value computation, but it is more than that: the model stacks many attention layers, so with, say, 12 layers we save that "current sequence length - 1" amount of computation 12 times over.
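As a shape-only toy sketch of the idea (not the full implementation shown further below): on each decoding step we project only the newest token into a Key and a Value and concatenate them with the cached ones, instead of re-projecting the entire prefix. All names and sizes here are illustrative assumptions:

import torch

# Toy dimensions, purely illustrative
head_num, head_size, hidden_size = 12, 64, 768
W_K = torch.nn.Linear(hidden_size, hidden_size)
W_V = torch.nn.Linear(hidden_size, hidden_size)

key_cache = torch.randn(1, head_num, 9, head_size)    # cached Keys of the 9 tokens decoded so far
value_cache = torch.randn(1, head_num, 9, head_size)  # cached Values of the 9 tokens decoded so far

new_hidden = torch.randn(1, 1, hidden_size)  # hidden state of the newest token only

new_k = W_K(new_hidden).view(1, 1, head_num, head_size).permute(0, 2, 1, 3)
new_v = W_V(new_hidden).view(1, 1, head_num, head_size).permute(0, 2, 1, 3)

# Only one token goes through the linear projections; the other 9 come straight from the cache
k = torch.cat((key_cache, new_k), dim=-2)    # (1, head_num, 10, head_size)
v = torch.cat((value_cache, new_v), dim=-2)  # (1, head_num, 10, head_size)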
For this part you can also refer to my earlier GPT-2 implementation, modeled on the HuggingFace Transformers architecture (https://github.com/ccs96307/gpt2-pytorch-implemented/tree/main).
You can see that I define GPT2Attention and GPT2MLP modules, which are then wrapped by GPT2Block and repeated for config.n_layer layers. Whenever use_cache is enabled, each GPT2Attention returns, at the end of its forward pass, a present cache wrapping (k, v); we simply collect these at the outermost level and pass them back in on every decoding step (see the usage sketch after the code listing).
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import GPT2Config
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

# `Conv1D` and `NewGELUActivation` are custom modules defined elsewhere in the same repository.


class GPT2Attention(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()

        # Init
        self.n_head = config.n_head
        self.head_size = int(config.hidden_size / self.n_head)
        self.scale = 1 / (self.head_size ** 0.5)
        self.hidden_size = config.hidden_size

        self.c_attn = Conv1D(input_dim=config.hidden_size, output_dim=3*config.hidden_size)
        self.c_proj = Conv1D(input_dim=config.hidden_size, output_dim=config.hidden_size)

        self.attn_dropout = torch.nn.Dropout(p=config.attn_pdrop)
        self.resid_dropout = torch.nn.Dropout(p=config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        batch_size = hidden_states.size(0)

        # Project the hidden states into Q, K, V with a single fused linear layer
        qkv = self.c_attn(hidden_states)
        q, k, v = qkv.split(self.hidden_size, dim=-1)

        # Reshape to (batch_size, n_head, seq_len, head_size)
        q = q.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
        k = k.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)
        v = v.contiguous().view(batch_size, -1, self.n_head, self.head_size).permute(0, 2, 1, 3)

        # KV Cache: prepend the cached keys/values along the sequence dimension
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        if use_cache:
            present = (k, v)
        else:
            present = None

        # Compute Q @ K^T
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores * self.scale

        # Causal mask: a query may only attend to keys at the same or earlier positions,
        # so the mask has to account for the cached (past) key length as well
        query_len = hidden_states.size(-2)
        key_len = k.size(-2)
        mask_value = torch.finfo(attention_scores.dtype).min
        causal_mask = torch.triu(
            torch.full((query_len, key_len), mask_value, dtype=attention_scores.dtype, device=attention_scores.device),
            diagonal=key_len - query_len + 1,
        )
        attention_scores = attention_scores + causal_mask

        # Attention (padding) mask
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
            attention_scores = attention_scores + attention_mask

        # Softmax
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.attn_dropout(attention_weights)

        # Weighted sum over V
        context_layer = torch.matmul(attention_weights, v)

        # Reshape back to (batch_size, seq_len, hidden_size) and project out
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(batch_size, -1, self.head_size * self.n_head)
        attention_output = self.c_proj(context_layer)

        # Dropout (the residual connection is applied in GPT2Block)
        attention_output = self.resid_dropout(attention_output)

        outputs = (attention_output, present)
        if output_attentions:
            outputs += (attention_weights,)

        return outputs
class GPT2MLP(torch.nn.Module):
    def __init__(self, inner_dim: int, config: GPT2Config) -> None:
        super().__init__()
        self.c_fc = Conv1D(input_dim=config.hidden_size, output_dim=inner_dim)
        self.c_proj = Conv1D(input_dim=inner_dim, output_dim=config.hidden_size)
        self.act = NewGELUActivation()
        self.dropout = torch.nn.Dropout(p=config.resid_pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
class GPT2Block(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size

        self.ln_1 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config=config)
        self.ln_2 = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states

        # Self-Attention
        hidden_states = self.ln_1(hidden_states)
        attention_outputs = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]  # attention_outputs: (attention_output, present, all_attentions)
        outputs = attention_outputs[1:]

        # Residual connection
        hidden_states = attention_output + residual
        residual = hidden_states

        # MLP
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual

        # Cache
        if use_cache:
            outputs = (hidden_states,) + outputs  # outputs: (hidden_states, present, all_attentions)
        else:
            outputs = (hidden_states,) + outputs[1:]  # outputs: (hidden_states, all_attentions)

        return outputs
class GPT2Model(torch.nn.Module):
    def __init__(self, config: GPT2Config) -> None:
        super().__init__()
        self.wte = torch.nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = torch.nn.Embedding(config.n_positions, config.n_embd)
        self.dropout = torch.nn.Dropout(p=config.embd_pdrop)
        self.h = torch.nn.ModuleList([GPT2Block(config=config) for _ in range(config.n_layer)])
        self.ln_f = torch.nn.LayerNorm(normalized_shape=config.n_embd, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        # KV Cache: if no cache is provided, every layer starts from scratch
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        # Token embeddings
        token_embeddings = self.wte(input_ids)

        # Position embeddings (offset by the cached length during incremental decoding)
        if position_ids is None:
            position_ids = torch.arange(past_length, past_length + input_ids.size(1), device=input_ids.device)
            position_embeddings = self.wpe(position_ids).view(1, -1, token_embeddings.size(-1))
        else:
            position_embeddings = self.wpe(position_ids)

        # Sum the embeddings
        embeddings = token_embeddings + position_embeddings

        # Computation
        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        hidden_states = self.dropout(embeddings)

        for block, layer_past in zip(self.h, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]

            if use_cache:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Final LayerNorm
        hidden_states = self.ln_f(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=None,
        )
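To show how the present caches returned by GPT2Model are actually reused, here is a minimal greedy-decoding sketch. It is only an illustration under a few assumptions: the lm_head output layer and the hard-coded prompt token IDs are hypothetical and are not part of the repository above.

# Minimal greedy-decoding sketch: the first step feeds the full prompt,
# every later step feeds only the newest token together with past_key_values.
config = GPT2Config()
model = GPT2Model(config=config)
lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)  # hypothetical output head

input_ids = torch.tensor([[464, 3290, 318]])  # arbitrary prompt token IDs
past_key_values = None

model.eval()
with torch.no_grad():
    for _ in range(10):
        outputs = model(
            input_ids=input_ids if past_key_values is None else input_ids[:, -1:],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values  # collected (k, v) for every layer
        next_token = lm_head(outputs.last_hidden_state[:, -1, :]).argmax(dim=-1, keepdim=True)
        input_ids = torch.cat((input_ids, next_token), dim=-1)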
Viewed from another angle, what the KV Cache saves is actually very intuitive: if the sequence length is only 10 tokens, we skip recomputing the Keys and Values of 9 of them on that decoding step; if the sequence length is 1,000,000,000, we skip 999,999,999 of them!
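To put a rough number on it (with GPT-2-small-like figures as assumptions: 12 layers, hidden_size 768, a 1,024-token context), the Key/Value projections alone that the cache lets us skip on a single decoding step amount to roughly:

# Back-of-the-envelope estimate of the K/V projection work skipped on ONE decoding step.
# The numbers are illustrative assumptions (GPT-2-small-like), not measurements.
n_layer, hidden_size, seq_len = 12, 768, 1024

# Without the cache, every step would re-project all previous tokens into K and V:
# 2 projections * (seq_len - 1) tokens * hidden_size^2 multiply-accumulates per token, per layer.
skipped_macs = n_layer * 2 * (seq_len - 1) * hidden_size * hidden_size
print(f"{skipped_macs:,} multiply-accumulates skipped")  # about 14.5 billion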
Do you feel the importance of the KV Cache now?