
[PyTorch] Using SDPA in 2.0+ to Speed Up Transformer Self-Attention Computation

Last Updated on 2024-03-25 by Clay

Introduction to SDPA

For anyone familiar with the Transformer self-attention (Self-Attention) architecture, Scaled Dot-Product Attention (SDPA) probably brings the following formula to mind instantly:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V


Starting with PyTorch 2.0+, torch.nn.functional provides a new, highly efficient function: torch.nn.functional.scaled_dot_product_attention.
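
For reference, a minimal call looks roughly like the sketch below. The attn_mask, dropout_p, and is_causal keyword arguments are optional parts of the official interface and are simply left at their defaults in the benchmarks later in this post:

import torch
import torch.nn.functional as F

# Shapes follow the (batch, num_heads, seq_len, head_size) convention used later in this post
query = torch.rand(4, 2, 100, 32)
key = torch.rand(4, 2, 100, 32)
value = torch.rand(4, 2, 100, 32)

# attn_mask, dropout_p, and is_causal are optional; shown here only to illustrate the interface
output = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(output.shape)  # torch.Size([4, 2, 100, 32])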

Under the hood, the function dispatches to high-performance kernels, so we can drop it straight into the self-attention mechanism of an existing Transformer architecture for better computational efficiency. According to the official documentation, PyTorch currently backs it with three implementations:

  1. FlashAttention
  2. Memory-Efficient Attention (xformers)
  3. A C++ implementation built into PyTorch

If the program is running on a CUDA backend, an optimized implementation is dispatched automatically; on other backends, the PyTorch implementation is used.
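
If you want to confirm which kernel is actually being used, PyTorch exposes context managers that restrict the dispatch. The sketch below is only a minimal example and assumes a CUDA device with half-precision tensors; torch.backends.cuda.sdp_kernel is the interface available since 2.0, while newer releases also offer torch.nn.attention.sdpa_kernel.

import torch
import torch.nn.functional as F

# Assumes a CUDA device; FlashAttention also expects fp16/bf16 inputs
query = torch.rand(4, 2, 100, 32, device="cuda", dtype=torch.float16)
key = torch.rand(4, 2, 100, 32, device="cuda", dtype=torch.float16)
value = torch.rand(4, 2, 100, 32, device="cuda", dtype=torch.float16)

# Allow only the FlashAttention backend; raises a RuntimeError if it cannot handle the inputs
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = F.scaled_dot_product_attention(query, key, value)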

Recently I have been reading the source code of the HuggingFace transformers package quite a lot, and I noticed that model implementations such as Llama, Mistral, and Gemma now all use torch.nn.functional.scaled_dot_product_attention to speed things up, except where FlashAttention2 is supported. This does not seem to have been the case in slightly older versions of transformers; at least I did not see it when I read the Mistral implementation in October 2023.
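
As far as I can tell, recent transformers releases (roughly 4.36 and later) even let you request this code path explicitly when loading a model; the model name below is just an example:

from transformers import AutoModelForCausalLM

# attn_implementation may be "eager", "sdpa", or "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # example model only
    attn_implementation="sdpa",
)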

So, below is my speed evaluation of this built-in implementation against my own, using torch.utils.benchmark to keep the comparison fair.


Verifying That the Results Match

First, I need to confirm that my implementation produces the same results as the official one.

Here is a comparison of the outputs of my implementation and torch.nn.functional.scaled_dot_product_attention:

import math
import torch


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention weights (Q * K^T)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast attention to fp32
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # Dropout is a no-op here (training=False), so it matches the default dropout_p=0.0 of the built-in SDPA
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output (A = Q * K^T, A * V)
    attention_output = torch.matmul(attention_weights, value)
    return attention_output


def main() -> None:
    # Init
    batch_size = 4
    num_heads = 2
    seq_len = 100
    head_size = 32

    # QKV
    query = torch.rand(batch_size, num_heads, seq_len, head_size)
    key = torch.rand(batch_size, num_heads, seq_len, head_size)
    value = torch.rand(batch_size, num_heads, seq_len, head_size)

    # SDPA output
    torch_sdpa_output = torch_sdpa(query=query, key=key, value=value)
    my_sdpa_output = my_sdpa(query=query, key=key, value=value, head_size=head_size)
    
    # Comparison
    print(torch.allclose(torch_sdpa_output, my_sdpa_output, atol=1e-10))


if __name__ == "__main__":
    main()


Output:

True


That said, the underlying kernels may round floating-point values slightly differently from plain PyTorch operations, so I have to limit the comparison precision to about ten decimal places (atol=1e-10) before my implementation is reported as consistent with torch.nn.functional.scaled_dot_product_attention.
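
To make that concrete, a small sketch like the following (independent of the script above) shows the size of the discrepancy and how the tolerance affects the comparison:

import math
import torch

query = torch.rand(4, 2, 100, 32)
key = torch.rand(4, 2, 100, 32)
value = torch.rand(4, 2, 100, 32)

torch_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
manual_output = torch.matmul(
    torch.nn.functional.softmax(torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(32), dim=-1),
    value,
)

# Largest element-wise difference between the two outputs
print(torch.max(torch.abs(torch_output - manual_output)))

# Passes with the default rtol plus atol=1e-10; a bit-exact comparison may fail
print(torch.allclose(torch_output, manual_output, atol=1e-10))
print(torch.allclose(torch_output, manual_output, rtol=0.0, atol=0.0))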


Benchmarking Computational Efficiency

import torch
import torch.utils.benchmark as benchmark
import math


# Init
batch_size = 4
num_heads = 2
seq_len = 100
head_size = 32

# QKV
query = torch.rand(batch_size, num_heads, seq_len, head_size)
key = torch.rand(batch_size, num_heads, seq_len, head_size)
value = torch.rand(batch_size, num_heads, seq_len, head_size)


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention weights (Q * K^T)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast attention to fp32
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output (A = Q * K^T, A * V)
    attention_output = torch.matmul(attention_weights, value)
    
    return attention_output


def benchmark_my_sdpa():
    return my_sdpa(query, key, value, head_size)


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def benchmark_torch_sdpa():
    return torch_sdpa(query, key, value)


# Testing
my_test = benchmark.Timer(
    stmt="benchmark_my_sdpa()",
    setup="from __main__ import benchmark_my_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

torch_test = benchmark.Timer(
    stmt="benchmark_torch_sdpa()",
    setup="from __main__ import benchmark_torch_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

# Result
print("My Implementation:", my_test.timeit(100000))
print("Torch Implementation:", torch_test.timeit(100000))


Output:

My Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_my_sdpa()
setup: from __main__ import benchmark_my_sdpa
  142.33 us
  1 measurement, 100000 runs , 8 threads
Torch Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_torch_sdpa()
setup: from __main__ import benchmark_torch_sdpa
  88.38 us
  1 measurement, 100000 runs , 8 threads


A total defeat! But that is only to be expected: the PyTorch backend has been heavily optimized, and it is not something a quick hand-rolled implementation can beat.

Still, after this test I know that whenever I need SDPA in the future, I should lean on the PyTorch implementation as much as possible. Interestingly, when I benchmarked torch.nn.MultiheadAttention, my own implementation actually won. After reading the source code, the best explanation I have is that torch.nn.MultiheadAttention is written for generality and therefore performs a lot of extra branching and input checking.
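
For context, a minimal sketch of this kind of torch.nn.MultiheadAttention timing (not my exact test code) could look like the following. Note that the module expects inputs of shape (batch, seq_len, embed_dim) with batch_first=True, where embed_dim = num_heads * head_size, and it performs the QKV projections itself, which is part of its overhead:

import torch
import torch.utils.benchmark as benchmark

batch_size, num_heads, seq_len, head_size = 4, 2, 100, 32
embed_dim = num_heads * head_size

# Self-attention input: (batch, seq_len, embed_dim) with batch_first=True
x = torch.rand(batch_size, seq_len, embed_dim)
mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
mha.eval()


def benchmark_mha():
    with torch.no_grad():
        # The module projects Q, K, V internally before the attention computation
        attn_output, _ = mha(x, x, x, need_weights=False)
    return attn_output


mha_test = benchmark.Timer(
    stmt="benchmark_mha()",
    setup="from __main__ import benchmark_mha",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)
print("MultiheadAttention:", mha_test.timeit(10000))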

