Last Updated on 2024-08-16 by Clay
SDPA Introduction
Scaled Dot-Product Attention (SDPA) will be immediately familiar to anyone who knows the Transformer self-attention mechanism.
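For reference, this is the standard definition from the original Transformer paper, written out as a formula. Here d_k is the dimension of each key vector, which corresponds to head_size in the code later in this post:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```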
In PyTorch 2.0+, a new efficient function, scaled_dot_product_attention, is available in torch.nn.functional.
This function is backed by high-performance kernels, allowing us to directly replace the self-attention mechanism in existing Transformer architectures for improved computational efficiency. According to the official documentation, PyTorch currently supports three implementations:
- FlashAttention
- Memory-Efficient Attention (xformers)
- PyTorch's C++ backend implementation
If the current environment's backend is CUDA, the function may automatically dispatch to one of the optimized implementations; on any other backend, it falls back to PyTorch's default implementation.
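Because the dispatch happens inside the function, the calling code looks the same on either backend. A minimal sketch, assuming only that PyTorch 2.0+ is installed (the tensor shapes are illustrative and match the ones used later in this post):

```python
import torch
import torch.nn.functional as F

# Use CUDA when it is available; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Shapes follow (batch_size, num_heads, seq_len, head_size); values are illustrative.
query = torch.rand(4, 2, 100, 32, device=device)
key = torch.rand(4, 2, 100, 32, device=device)
value = torch.rand(4, 2, 100, 32, device=device)

# The same call works on both devices; kernel selection is internal to PyTorch.
output = F.scaled_dot_product_attention(query, key, value)
print(device, tuple(output.shape))
```

The output has the same shape as the query, so the function can replace a hand-written attention step without changing the surrounding code.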
Recently, I've been reading the source code of HuggingFace's transformers library and discovered that models such as Llama, Mistral, and Gemma use torch.nn.functional.scaled_dot_product_attention for speed-ups, except when FlashAttention2 is supported. This was not present in slightly older versions of the transformers library, at least when I read the Mistral implementation in October 2023.
Below is my speed evaluation comparing this function with my own implementation, using torch.utils.benchmark to keep the comparison fair.
Ensuring Consistency
First, I need to verify that my implementation gives the same result as the official one.
Below is a comparison of the output between my implementation and torch.nn.functional.scaled_dot_product_attention:
import math

import torch


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention scores (Q * K^T), scaled by sqrt(head_size)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast attention to fp32 for the softmax, then cast back
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # training=False makes dropout a no-op here
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output (A = softmax(Q * K^T / sqrt(head_size)), output = A * V)
    attention_output = torch.matmul(attention_weights, value)

    return attention_output


def main() -> None:
    # Init
    batch_size = 4
    num_heads = 2
    seq_len = 100
    head_size = 32

    # QKV
    query = torch.rand(batch_size, num_heads, seq_len, head_size)
    key = torch.rand(batch_size, num_heads, seq_len, head_size)
    value = torch.rand(batch_size, num_heads, seq_len, head_size)

    # SDPA output
    torch_sdpa_output = torch_sdpa(query=query, key=key, value=value)
    my_sdpa_output = my_sdpa(query=query, key=key, value=value, head_size=head_size)

    # Comparison
    print(torch.allclose(torch_sdpa_output, my_sdpa_output, atol=1e-10))


if __name__ == "__main__":
    main()
Output:
True
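The check passes, but the two outputs are not bit-identical: there are marginal floating-point differences between PyTorch's underlying implementation and mine. Note that torch.allclose also applies its default relative tolerance (rtol=1e-5) on top of the atol=1e-10 set here, so the comparison is looser than 10 decimal places.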
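To see how large the floating-point differences actually are, one option is to print the maximum absolute difference and to run the comparison with the relative tolerance switched off. The following is a minimal sketch under the same shapes as above; the fixed seed and the variable names are my own choices for illustration, and the manual version omits the no-op dropout:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so repeated runs use the same inputs
head_size = 32
query, key, value = (torch.rand(4, 2, 100, head_size) for _ in range(3))

# Reference output from PyTorch's built-in function
reference = F.scaled_dot_product_attention(query, key, value)

# Manual computation: softmax(Q * K^T / sqrt(head_size)) * V
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)
manual = torch.matmul(torch.softmax(scores, dim=-1), value)

max_abs_diff = (reference - manual).abs().max().item()
print("max abs diff:", max_abs_diff)
# Default tolerances: rtol=1e-5, atol=1e-8
print("default tolerances:", torch.allclose(reference, manual))
# Absolute tolerance only, with the relative term disabled
print("absolute only:", torch.allclose(reference, manual, rtol=0.0, atol=1e-10))
```

The exact numbers depend on the hardware and the PyTorch build, so no expected output is shown here.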
Performance Evaluation
import math

import torch
import torch.utils.benchmark as benchmark

# Init
batch_size = 4
num_heads = 2
seq_len = 100
head_size = 32

# QKV (created on CPU, so this benchmark measures the CPU path)
query = torch.rand(batch_size, num_heads, seq_len, head_size)
key = torch.rand(batch_size, num_heads, seq_len, head_size)
value = torch.rand(batch_size, num_heads, seq_len, head_size)


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention scores (Q * K^T), scaled by sqrt(head_size)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast attention to fp32 for the softmax, then cast back
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # training=False makes dropout a no-op here
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output (A = softmax(Q * K^T / sqrt(head_size)), output = A * V)
    attention_output = torch.matmul(attention_weights, value)

    return attention_output


def benchmark_my_sdpa():
    return my_sdpa(query, key, value, head_size)


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def benchmark_torch_sdpa():
    return torch_sdpa(query, key, value)


# Testing
my_test = benchmark.Timer(
    stmt="benchmark_my_sdpa()",
    setup="from __main__ import benchmark_my_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

torch_test = benchmark.Timer(
    stmt="benchmark_torch_sdpa()",
    setup="from __main__ import benchmark_torch_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

# Result
print("My Implementation:", my_test.timeit(100000))
print("Torch Implementation:", torch_test.timeit(100000))
Output:
My Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_my_sdpa()
setup: from __main__ import benchmark_my_sdpa
142.33 us
1 measurement, 100000 runs , 8 threads
Torch Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_torch_sdpa()
setup: from __main__ import benchmark_torch_sdpa
88.38 us
1 measurement, 100000 runs , 8 threads
Completely outperformed! My implementation took 142.33 us per call against 88.38 us for PyTorch's, roughly 1.6x slower. But this is to be expected since PyTorch's backend is highly optimized and not something that can be easily outdone.
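Rather than reading the numbers off the printed Measurement objects, the ratio can also be computed directly from their mean attribute. This is a minimal sketch, not the benchmark script above: the helper names manual_sdpa and time_call are hypothetical, the manual version omits the no-op dropout, and the run count is kept small so it finishes quickly:

```python
import math

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

head_size = 32
query, key, value = (torch.rand(4, 2, 100, head_size) for _ in range(3))


def manual_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # softmax(Q * K^T / sqrt(head_size)) * V
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_size)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def time_call(fn, runs: int = 200) -> float:
    # Hypothetical helper: returns the mean time per call in seconds.
    timer = benchmark.Timer(
        stmt="fn(query, key, value)",
        globals={"fn": fn, "query": query, "key": key, "value": value},
    )
    return timer.timeit(runs).mean


manual_s = time_call(manual_sdpa)
fused_s = time_call(F.scaled_dot_product_attention)
print(f"manual: {manual_s * 1e6:.2f} us")
print(f"torch:  {fused_s * 1e6:.2f} us")
print(f"ratio:  {manual_s / fused_s:.2f}x")
```

The ratio will vary with hardware, thread count, and tensor sizes, so it should not be expected to match the 1.6x figure above exactly.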
Interestingly, when I tested torch.nn.MultiheadAttention, my implementation actually outperformed it. After reading the source code, I concluded that this might be because the torch.nn.MultiheadAttention implementation handles many general-use cases, which leads to a large number of conditional checks.
References
- (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Memory-Efficient Attention