
[PyTorch] Using SDPA in 2.0+ to Improve the Computation Speed of Transformer’s Self-Attention Mechanism

Last Updated on 2024-08-16 by Clay

SDPA Introduction

Scaled Dot-Product Attention (SDPA) might immediately pop into the minds of those familiar with the Transformer self-attention mechanism:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V


In PyTorch 2.0+, a new, efficient attention function is available: torch.nn.functional.scaled_dot_product_attention.
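As a quick orientation, here is a minimal usage sketch (the tensor shapes and the is_causal flag are illustrative choices, not part of the original benchmark below):

import torch
import torch.nn.functional as F

# SDPA expects tensors of shape (batch, num_heads, seq_len, head_dim)
q = torch.rand(4, 2, 100, 32)
k = torch.rand(4, 2, 100, 32)
v = torch.rand(4, 2, 100, 32)

# Optional arguments such as attn_mask, dropout_p, and is_causal are available
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([4, 2, 100, 32])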

This function is backed by high-performance kernels, allowing us to directly replace the self-attention mechanism in existing Transformer architectures for improved computational efficiency. According to the official documentation, PyTorch currently supports three implementations:

  1. FlashAttention
  2. Memory-Efficient Attention (xformers)
  3. PyTorch's C++ backend implementation

If the tensors are on a CUDA device and the inputs are supported, PyTorch automatically dispatches to one of the optimized kernels; otherwise it falls back to the default C++ (math) implementation.
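If you want to verify which kernel is actually being used, you can restrict the allowed backends yourself. The sketch below assumes a CUDA device, fp16 tensors, and PyTorch 2.0–2.2 (newer releases expose torch.nn.attention.sdpa_kernel for the same purpose):

import torch
import torch.nn.functional as F

q = torch.rand(4, 2, 100, 32, device="cuda", dtype=torch.float16)
k = torch.rand(4, 2, 100, 32, device="cuda", dtype=torch.float16)
v = torch.rand(4, 2, 100, 32, device="cuda", dtype=torch.float16)

# Allow only the FlashAttention kernel; with the math kernel disabled, SDPA
# raises an error instead of silently falling back when FlashAttention
# cannot handle the inputs.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)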

Recently, while reading the source code of HuggingFace's transformers library, I noticed that models such as Llama, Mistral, and Gemma use torch.nn.functional.scaled_dot_product_attention for speed-ups whenever FlashAttention-2 is not in use. This was not the case in slightly older versions of the library; at least it was absent when I read the Mistral implementation in October 2023.
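For reference, recent transformers releases (roughly 4.36 and later, if I recall correctly) let you pick the attention implementation explicitly when loading a model; the checkpoint name below is only an example:

import torch
from transformers import AutoModelForCausalLM

# Request the SDPA-based attention explicitly; "flash_attention_2" and "eager"
# are the other common options.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)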

Now, here is my speed evaluation comparing this function with my own implementation, using torch.utils.benchmark to keep the comparison fair.


Ensuring Consistency

First, I need to verify that my implementation gives the same result as the official one.

Below is a comparison of the output between my implementation and torch.nn.functional.scaled_dot_product_attention:

import math
import torch


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention scores: (Q @ K^T) / sqrt(d_k)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast to fp32 for the softmax, then cast back to the input dtype
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # Dropout is a no-op here (training=False); kept to mirror typical implementations
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output: softmax weights @ V
    attention_output = torch.matmul(attention_weights, value)
    return attention_output


def main() -> None:
    # Init
    batch_size = 4
    num_heads = 2
    seq_len = 100
    head_size = 32

    # QKV
    query = torch.rand(batch_size, num_heads, seq_len, head_size)
    key = torch.rand(batch_size, num_heads, seq_len, head_size)
    value = torch.rand(batch_size, num_heads, seq_len, head_size)

    # SDPA output
    torch_sdpa_output = torch_sdpa(query=query, key=key, value=value)
    my_sdpa_output = my_sdpa(query=query, key=key, value=value, head_size=head_size)
    
    # Comparison
    print(torch.allclose(torch_sdpa_output, my_sdpa_output, atol=1e-10))


if __name__ == "__main__":
    main()


Output:

True


However, because of marginal floating-point differences between PyTorch's underlying implementation and mine, the outputs are not bit-identical; I compare them with torch.allclose at a tolerance of 1e-10 (ten decimal places), and within that tolerance the results match.
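If you want to see how large that gap actually is, a small addition to main() above prints the maximum absolute difference instead of only a boolean (the exact value will vary from run to run):

# Inside main(), after computing both outputs:
max_diff = (torch_sdpa_output - my_sdpa_output).abs().max().item()
print(f"Max absolute difference: {max_diff:.2e}")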


Performance Evaluation

import torch
import torch.utils.benchmark as benchmark
import math


# Init
batch_size = 4
num_heads = 2
seq_len = 100
head_size = 32

# QKV
query = torch.rand(batch_size, num_heads, seq_len, head_size)
key = torch.rand(batch_size, num_heads, seq_len, head_size)
value = torch.rand(batch_size, num_heads, seq_len, head_size)


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: int) -> torch.Tensor:
    # Attention scores: (Q @ K^T) / sqrt(d_k)
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_size)

    # Upcast to fp32 for the softmax, then cast back to the input dtype
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # Dropout is a no-op here (training=False); kept to mirror typical implementations
    attention_weights = torch.nn.functional.dropout(attention_weights, p=0.1, training=False)

    # Attention output: softmax weights @ V
    attention_output = torch.matmul(attention_weights, value)
    return attention_output


def benchmark_my_sdpa():
    return my_sdpa(query, key, value, head_size)


def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
    )


def benchmark_torch_sdpa():
    return torch_sdpa(query, key, value)


# Testing
my_test = benchmark.Timer(
    stmt="benchmark_my_sdpa()",
    setup="from __main__ import benchmark_my_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

torch_test = benchmark.Timer(
    stmt="benchmark_torch_sdpa()",
    setup="from __main__ import benchmark_torch_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

# Result
print("My Implementation:", my_test.timeit(100000))
print("Torch Implementation:", torch_test.timeit(100000))


Output:

My Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_my_sdpa()
setup: from __main__ import benchmark_my_sdpa
  142.33 us
  1 measurement, 100000 runs , 8 threads
Torch Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7f8ca3a18c10>
benchmark_torch_sdpa()
setup: from __main__ import benchmark_torch_sdpa
  88.38 us
  1 measurement, 100000 runs , 8 threads


My implementation was thoroughly outperformed! But that is to be expected: PyTorch's backend is highly optimized and not something that can be easily beaten.

Interestingly, when I benchmarked torch.nn.MultiheadAttention, my implementation actually came out ahead. After reading its source code, I suspect this is because torch.nn.MultiheadAttention has to cover many general use cases, which leads to a large number of conditional checks before the actual computation.
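As a rough sketch of such a comparison (not the exact script used above), something like the following works; note that nn.MultiheadAttention also performs the QKV and output projections, so it does strictly more work than a bare SDPA call:

import torch
import torch.utils.benchmark as benchmark

batch_size, num_heads, seq_len, head_size = 4, 2, 100, 32
embed_dim = num_heads * head_size

# batch_first=True -> inputs of shape (batch, seq_len, embed_dim)
mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
mha.eval()

x = torch.rand(batch_size, seq_len, embed_dim)


def benchmark_mha():
    with torch.no_grad():
        return mha(x, x, x, need_weights=False)


mha_test = benchmark.Timer(
    stmt="benchmark_mha()",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

print("nn.MultiheadAttention:", mha_test.timeit(10000))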

