Last Updated on 2024-03-25 by Clay
Introduction to SDPA
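For anyone familiar with the Transformer self-attention architecture, Scaled Dot-Product Attention (SDPA) probably brings one formula to mind immediately:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V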
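As a minimal sketch of that computation, softmax(QK^T / sqrt(d_k)) V, here is a dependency-free Python version for a single attention head. The function name `sdpa_reference` and the toy inputs are illustrative assumptions, not part of PyTorch.

```python
import math


def sdpa_reference(query, key, value):
    """Scaled dot-product attention for one head, using plain Python lists.

    query: [seq_q][d_k], key: [seq_k][d_k], value: [seq_k][d_v]
    """
    d_k = len(query[0])
    d_v = len(value[0])
    output = []
    for q in query:
        # Scaled scores: (q . k) / sqrt(d_k) for every key vector
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in key]
        # Numerically stable softmax over the keys
        max_score = max(scores)
        exps = [math.exp(s - max_score) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors
        output.append([sum(w * v[j] for w, v in zip(weights, value)) for j in range(d_v)])
    return output


# Hypothetical toy input: two tokens, head size 2
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(sdpa_reference(q, k, v))
```

Each output row is a convex combination of the value rows, so when all keys are identical the result is simply the mean of the values.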
Since PyTorch 2.0, torch.nn.functional also provides a new, efficient function for this computation: torch.nn.functional.scaled_dot_product_attention.
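The function is backed by high-performance kernels, so it can directly replace the self-attention computation in an existing Transformer architecture and make it more efficient. According to the official documentation, PyTorch currently supports three implementations under the hood: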
- FlashAttention
- Memory-Efficient Attention (xformers)
- A C++ implementation inside PyTorch
If the program runs on a CUDA backend, the optimized implementations are selected automatically; on any other backend, the PyTorch implementation is used.
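One way to inspect this from Python is to query the global backend switches in `torch.backends.cuda`. The sketch below only reports whether each backend is allowed; which one a given call actually uses also depends on dtype, shapes, and hardware, which this snippet does not inspect. The dictionary name `backend_flags` is an assumption for illustration.

```python
import torch

# Global switches for the three SDPA backends (available since PyTorch 2.0).
# They report whether a backend is *allowed*, not which one a call will pick.
backend_flags = {
    "flash": torch.backends.cuda.flash_sdp_enabled(),
    "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
    "math": torch.backends.cuda.math_sdp_enabled(),
}

print("CUDA available:", torch.cuda.is_available())
print(backend_flags)
```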
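Recently I have been reading the source code of the HuggingFace transformers package a lot, and I noticed that model implementations such as Llama, Mistral, and Gemma all use torch.nn.functional.scaled_dot_product_attention to speed up the computation, except in the cases where FlashAttention2 is supported. This does not seem to have been there in slightly older versions of transformers; at least I did not see it when I read the Mistral implementation in October 2023.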
Below is my speed comparison between this implementation and my own, using torch.utils.benchmark to keep the comparison fair.
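Verifying That the Results Match
First, I need to confirm that my implementation produces the same result as the official one.
The following compares the output of my implementation with that of torch.nn.functional.scaled_dot_product_attention: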
import math

import torch


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention weights (Q * K^T)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast attention to fp32
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output (A = Q * K^T, A * V)
    attention_output = torch.matmul(attention_weights, value)

    return attention_output


def main() -> None:
    # Init
    batch_size = 4
    num_heads = 2
    seq_len = 100
    head_size = 32

    # QKV
    query = torch.rand(batch_size, num_heads, seq_len, head_size)
    key = torch.rand(batch_size, num_heads, seq_len, head_size)
    value = torch.rand(batch_size, num_heads, seq_len, head_size)

    # SDPA output
    torch_sdpa_output = torch_sdpa(query=query, key=key, value=value)
    my_sdpa_output = my_sdpa(query=query, key=key, value=value, head_size=head_size)

    # Comparison
    print(torch.allclose(torch_sdpa_output, my_sdpa_output, atol=1e-10))


if __name__ == "__main__":
    main()
Output:
True
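Because the results of the underlying kernels can differ from the plain PyTorch ops in the last few decimal places, the two outputs may not be bit-for-bit identical. I therefore compare them with torch.allclose using an absolute tolerance of 1e-10 (its default relative tolerance of 1e-5 still applies); within that tolerance my implementation matches torch.nn.functional.scaled_dot_product_attention.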
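To see how large the discrepancy actually is, rather than only whether it passes a threshold, one option is to print the maximum absolute difference. The sketch below inlines a reference computation equivalent to `my_sdpa`; the helper name `reference_sdpa` and the fixed seed are assumptions added for illustration.

```python
import math

import torch


def reference_sdpa(query, key, value):
    # Same math as my_sdpa: softmax(Q K^T / sqrt(head_size)) V
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


torch.manual_seed(0)  # assumption: fixed seed so repeated runs use the same inputs
query, key, value = (torch.rand(4, 2, 100, 32) for _ in range(3))

fused = torch.nn.functional.scaled_dot_product_attention(query, key, value)
manual = reference_sdpa(query, key, value)

# Largest element-wise gap between the two implementations
max_abs_diff = (fused - manual).abs().max().item()
print(f"max abs diff: {max_abs_diff:.3e}")

# Same check with the torch.allclose default tolerances written out explicitly
print(torch.allclose(fused, manual, rtol=1e-5, atol=1e-8))
```

The exact difference will vary with hardware, dtype, and the backend that gets selected, so the printed number is best read as an order-of-magnitude indicator.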
Benchmarking Computational Efficiency
import math

import torch
import torch.utils.benchmark as benchmark

# Init
batch_size = 4
num_heads = 2
seq_len = 100
head_size = 32

# QKV
query = torch.rand(batch_size, num_heads, seq_len, head_size)
key = torch.rand(batch_size, num_heads, seq_len, head_size)
value = torch.rand(batch_size, num_heads, seq_len, head_size)


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention weights (Q * K^T)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast attention to fp32
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output (A = Q * K^T, A * V)
    attention_output = torch.matmul(attention_weights, value)

    return attention_output


def benchmark_my_sdpa():
    return my_sdpa(query, key, value, head_size)


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def benchmark_torch_sdpa():
    return torch_sdpa(query, key, value)


# Testing: both timers use the same inputs and the same thread count
my_test = benchmark.Timer(
    stmt="benchmark_my_sdpa()",
    setup="from __main__ import benchmark_my_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

torch_test = benchmark.Timer(
    stmt="benchmark_torch_sdpa()",
    setup="from __main__ import benchmark_torch_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

# Result
print("My Implementation:", my_test.timeit(100000))
print("Torch Implementation:", torch_test.timeit(100000))
Output:
My Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_my_sdpa()
setup: from __main__ import benchmark_my_sdpa
142.33 us
1 measurement, 100000 runs , 8 threads
Torch Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_torch_sdpa()
setup: from __main__ import benchmark_torch_sdpa
88.38 us
1 measurement, 100000 runs , 8 threads
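A crushing defeat! That is only to be expected, though: PyTorch's low-level implementation has been heavily optimized and is not something a casual implementation can beat.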
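After this test, my takeaway is that whenever I need SDPA in the future, I should rely on the PyTorch implementation as much as possible. Interestingly, when I tested torch.nn.MultiheadAttention, my implementation came out ahead. After reading the source code, my best guess is that torch.nn.MultiheadAttention was written for generality and therefore contains a lot of conditional branching.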
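When comparing against `torch.nn.MultiheadAttention`, it may help to keep in mind that the module does more than the bare SDPA call: it also applies input and output projections. Below is a minimal sketch of calling it on inputs of comparable size. The choice `embed_dim = num_heads * head_size` and the tensor shapes are assumptions made to mirror the benchmark settings, not the setup actually used for the comparison mentioned here.

```python
import torch

batch_size, num_heads, seq_len, head_size = 4, 2, 100, 32
# MultiheadAttention takes the full embedding, not per-head tensors
embed_dim = num_heads * head_size

mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
mha.eval()

x = torch.rand(batch_size, seq_len, embed_dim)
with torch.no_grad():
    # need_weights=False skips returning the averaged attention matrix
    output, _ = mha(x, x, x, need_weights=False)

print(output.shape)
```

Because of the extra projections and the different input layout, timings of this module and of a bare SDPA function are not a like-for-like comparison.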
References
- (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Memory-Efficient Attention