Last Updated on 2024-09-08 by Clay
Introduction
Triton is an open-source GPU programming language and compiler released by OpenAI in 2021. In recent years it has become increasingly popular among developers for writing and optimizing parallel programs on GPUs. Compared to traditional approaches such as CUDA or OpenCL, Triton offers a Python-like syntax, making it more readable and easier to learn.
Triton has the following features:
- Simplified Memory Access: Triton abstracts a high-performance memory access interface, allowing developers to load/store data through high-level operations without manually managing complex memory allocations.
- Automatic Performance Optimization: Triton's compiler takes care of many low-level details on its own, such as memory coalescing and on-chip memory management, reducing the need for manual tuning.
- Python Integration: Triton is a Python library, so kernels are written and launched directly from Python code, which makes adding GPU acceleration more straightforward.
In this post, I will walk through the first tutorial from Triton's official documentation and include the additional resources I had to look up for better understanding. Hopefully this will help those who want to learn Triton, and also serve as a future reference for myself.
Installation
I assume anyone learning Triton is already an experienced developer, so I won't go into unnecessary detail. Please install packages inside a Python virtual environment to avoid affecting the system installation.
pip install triton
If you wish to install the latest preview version or build from source, please refer to the references at the end of this post.
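To confirm the installation works before moving on, a quick check like the following is enough (this is my own sanity check, not part of the official tutorial; it assumes a CUDA-capable GPU and an existing PyTorch installation):
import torch
import triton

print(triton.__version__)          # the installed Triton version
print(torch.cuda.is_available())   # the tutorial below requires a CUDA device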
Example Program: Vector Addition
We will learn the following:
- How to use the triton.jit decorator to define a basic Triton kernel
- How to verify a kernel's results against PyTorch and use Triton's benchmarking utilities to measure its performance
Once we import the necessary modules, we define a Triton kernel using the @triton.jit decorator (JIT stands for just-in-time compilation).
Here, x_ptr, y_ptr, and output_ptr are all pointers. Triton kernels receive pointers rather than objects; a pointer is the memory address of a variable, which the kernel reads from and writes to directly. Additionally, output is an empty array we need to preallocate, which will store the result of adding x and y.
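In practice we never build these pointers ourselves: we pass torch.Tensor objects at the call site and Triton hands the kernel the tensors' underlying device addresses. The snippet below only illustrates what such an address looks like; it is not needed anywhere in this tutorial:
import torch

x = torch.rand(4, device="cuda")
print(hex(x.data_ptr()))  # the raw device address the kernel would receive as x_ptr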
The pid retrieves the current program ID; a program is the basic execution unit in Triton. Here we use axis=0 because we launch a 1D grid.
block_start = pid * BLOCK_SIZE computes the starting point of the data block that the current program is processing.
mask = offsets < n_elements creates a mask to prevent accessing elements beyond the end of the vector, since the number of elements isn't always an exact multiple of the block size.
x = tl.load(x_ptr + offsets, mask=mask) and y = tl.load(y_ptr + offsets, mask=mask) use this mask when loading the vectors.
tl.store(output_ptr + offsets, output, mask=mask) stores the sum of x and y at the memory addresses given by output_ptr plus offsets, again using mask so that only valid elements are written.
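To make the indexing concrete, here is the arithmetic for the last program when we later launch the kernel with 98432 elements and BLOCK_SIZE = 1024 (my own worked example, not from the official tutorial):
# Worked example for the final program of a later launch (98432 elements, BLOCK_SIZE = 1024):
pid = 96                                            # the last of ceil(98432 / 1024) = 97 programs
block_start = pid * 1024                            # 98304
offsets = range(block_start, block_start + 1024)    # 98304 .. 99327
valid = sum(1 for o in offsets if o < 98432)        # the mask keeps 128 of the 1024 lanes
print(block_start, valid)                           # 98304 128; the other 896 lanes load/store nothing
With that picture in mind, here is the full kernel: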
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here.
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
After defining the Triton kernel for adding two vectors, we now create the function add() to invoke the add_kernel() we just defined:
def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
One confusing part is why grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) is defined as a lambda function. The reason is that add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) launches the add_kernel defined earlier, and the launch grid has to be computed from the kernel's meta-parameters (here, BLOCK_SIZE), so we pass a function that receives those meta-parameters and returns the grid. The grid size is the number of program instances needed to cover all elements, which is why we use triton.cdiv(), a ceiling division that always returns a whole number: for example, triton.cdiv(100, 3) returns 34 rather than 33.333333, meaning 34 blocks of size 3 are needed to cover 100 elements.
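As a quick sanity check of that arithmetic with the sizes used later in this post (a small illustration of mine, not part of the official tutorial):
import triton

n_elements = 98432   # the vector size used in the verification step below
BLOCK_SIZE = 1024
print(triton.cdiv(n_elements, BLOCK_SIZE))  # 97 -> 97 programs, the last one partially masked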
Let’s compare the results with PyTorch’s native addition function to see if they are consistent.
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device="cuda")
y = torch.rand(size, device="cuda")
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
Output:
tensor([1.3713, 1.3076, 0.4940, ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940, ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
The maximum difference between torch and triton is 0.0
No difference at all!
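If you prefer a single pass/fail check over eyeballing the maximum difference, torch.allclose works just as well (a small addition of mine, not part of the official tutorial):
print(torch.allclose(output_torch, output_triton))  # True when the results match within tolerance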
That’s it, our first simple Triton kernel is complete. It really aligns with Python developers' preferences—simple and easy to understand.
Lastly, let's run Triton's benchmark. Triton's benchmarking utilities produce both a plot and tabular output, and we will rely on them frequently later, often to compare against other optimization techniques. Today, however, we are only comparing Triton with native PyTorch.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=[2**i for i in range(12, 28)],
        x_log=True,  # x axis is logarithmic
        line_arg="provider",
        line_names=["Triton", "Torch"],
        line_vals=["triton", "torch"],
        styles=[("blue", "-"), ("green", "-")],  # Line styles
        ylabel="GB/s",  # Label name for the y-axis
        plot_name="vector-add-performance",
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device="cuda", dtype=torch.float32)
    y = torch.rand(size, device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
    # 3 tensors (x, y, output) * 4 bytes per float32 element = 12 bytes moved per element;
    # dividing by the runtime converts this to memory throughput in GB/s.
    gbps = lambda ms: 12 * size / ms * 1e-6
    return list(map(gbps, [ms, min_ms, max_ms]))


benchmark.run(print_data=True, show_plots=True)
Output:
vector-add-performance:
           size      Triton       Torch
0        4096.0    8.000000    9.600000
1        8192.0   18.177514   16.786886
2       16384.0   31.999999   31.999999
3       32768.0   56.237985   63.503880
4       65536.0   76.800002   76.800002
5      131072.0  109.714284  109.714284
6      262144.0  165.913932  134.111870
7      524288.0  175.542852  175.542852
8     1048576.0  195.047621  198.193551
9     2097152.0  201.442627  203.107443
10    4194304.0  214.637557  215.578943
11    8388608.0  223.418180  222.250103
12   16777216.0  225.366931  223.672355
13   33554432.0  229.682243  228.066990
14   67108864.0  230.828288  224.566540
15  134217728.0  228.096453  224.454365
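If you are running this in an environment without a display, the same report can be written to disk instead; benchmark.run() accepts a save_path argument, and the directory name below is just an example of mine:
benchmark.run(print_data=True, show_plots=False, save_path="./vector-add-report")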