OpenAI Triton Note (1): Vector Addition

Last Updated on 2024-09-08 by Clay

Introduction

Triton is an open-source GPU programming language and compiler released by OpenAI in 2021. In recent years it has become increasingly popular among developers for writing and optimizing parallel programs on GPUs. Compared with lower-level options such as CUDA or OpenCL, Triton offers a Python-like syntax, making it more readable and easier to learn.

Triton has the following features:

  • Simplified Memory Access: Triton exposes block-level load/store operations (tl.load, tl.store), so developers read and write whole blocks of data without manually managing thread indices or shared memory.
  • Automatic Performance Optimization: Triton's compiler handles low-level optimizations such as memory coalescing and shared-memory management, and kernels can be auto-tuned over launch parameters (for example with triton.autotune), which reduces the need for manual optimization.
  • Python Integration: Triton kernels are written in Python and launched directly from Python code (for example on PyTorch tensors), which makes GPU acceleration more straightforward.

Below, I document how I worked through the first tutorial in Triton's official documentation, along with the extra material I had to look up to understand it. I hope this helps others who want to learn Triton and also serves as a future reference for myself.


Installation

Since anyone learning Triton is probably already an experienced developer, I won't go into unnecessary detail. I recommend installing the packages inside a Python virtual environment so they don't affect the system Python.

pip install triton


If you wish to install the latest preview version or build from source, please refer to the references at the end of this post.
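After installing, a quick sanity check can confirm that the packages are importable before running any kernels. This is a small sketch of my own using only the standard library (`check_environment` is a hypothetical helper, not part of Triton): it reports whether `triton` and `torch` can be found and, if PyTorch is present, whether it sees a CUDA device, which the examples in this post require.

```python
# Post-install sanity check (a sketch; `check_environment` is a hypothetical helper).
import importlib
import importlib.util


def check_environment(packages=("triton", "torch")):
    """Return {package_name: installed?} and print a short report."""
    status = {name: importlib.util.find_spec(name) is not None for name in packages}
    for name, ok in status.items():
        print(f"{name}: {'installed' if ok else 'NOT installed'}")
    # The examples below allocate tensors on device="cuda", so check for a GPU too.
    if status.get("torch"):
        torch = importlib.import_module("torch")
        print("CUDA available:", torch.cuda.is_available())
    return status


if __name__ == "__main__":
    check_environment()
```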


Example Program: Vector Addition

We will learn the following:

  • How to use the triton.jit decorator to define a basic Triton kernel
  • How to verify a Triton kernel's output and benchmark its performance with Triton's benchmarking utilities


Once we import the necessary modules, we will define a Triton kernel using the @triton.jit decorator (jit stands for just-in-time compilation).

Here, x_ptr, y_ptr, and output_ptr are pointers. Triton kernels operate on pointers rather than tensor objects: when we launch the kernel with PyTorch tensors, Triton passes a pointer to each tensor's first element, and the kernel reads and writes memory through that address. output_ptr points to an output tensor that we must preallocate before the launch; it will hold the sum of x and y.

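pid = tl.program_id(axis=0) retrieves the ID of the current program. A program is Triton's basic unit of execution (roughly comparable to a thread block in CUDA); many instances of the kernel run in parallel, each with its own ID. We use axis=0 because the launch grid is 1D.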

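block_start = pid * BLOCK_SIZE computes the index of the first element in the block handled by the current program, and offsets = block_start + tl.arange(0, BLOCK_SIZE) expands that into the indices of all BLOCK_SIZE elements in the block.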
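To make the mapping from pid to data block concrete, here is a plain-Python sketch (no GPU or Triton needed) that loops over program IDs one after another and computes each program's block_start and offsets, reproducing the 256-element, BLOCK_SIZE=64 example from the kernel's comments. On a real GPU these programs run in parallel; the helpers `program_offsets` and `partition` are hypothetical.

```python
# Plain-Python sketch of how a 1D launch grid partitions a vector among programs.
# `program_offsets` and `partition` are hypothetical helpers; in Triton,
# tl.program_id(axis=0) and tl.arange(0, BLOCK_SIZE) play these roles.

def program_offsets(pid, block_size):
    """Indices handled by program `pid`: block_start + [0, 1, ..., block_size - 1]."""
    block_start = pid * block_size
    return list(range(block_start, block_start + block_size))


def partition(n_elements, block_size):
    """Return the [first, last + 1) index range of every program in the grid."""
    num_programs = n_elements // block_size  # assumes exact division in this sketch
    ranges = []
    for pid in range(num_programs):
        offsets = program_offsets(pid, block_size)
        ranges.append((offsets[0], offsets[-1] + 1))
    return ranges


if __name__ == "__main__":
    print(partition(256, 64))  # → [(0, 64), (64, 128), (128, 192), (192, 256)]
```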

mask = offsets < n_elements creates a mask that prevents accesses past the end of the vector. This is needed because the vector length isn't always a multiple of BLOCK_SIZE, so the last block can extend beyond the data.
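Using the numbers from the correctness check later in this post (98432 elements, BLOCK_SIZE=1024), a plain-Python sketch shows why the mask matters: 98432 is not a multiple of 1024, so the last program's offsets run past the end of the vector. The helper `block_mask` is hypothetical and only mimics `offsets < n_elements`.

```python
# Plain-Python sketch of `mask = offsets < n_elements` for the last (partial) block.
# `block_mask` is a hypothetical helper; Triton computes the same thing as a vector op.

def block_mask(pid, block_size, n_elements):
    """Return (offsets, mask) for program `pid`, mimicking the Triton kernel."""
    block_start = pid * block_size
    offsets = [block_start + i for i in range(block_size)]
    mask = [off < n_elements for off in offsets]
    return offsets, mask


if __name__ == "__main__":
    n_elements, block_size = 98432, 1024
    num_programs = -(-n_elements // block_size)  # ceiling division: 97 programs
    offsets, mask = block_mask(num_programs - 1, block_size, n_elements)
    print(offsets[0], offsets[-1])  # → 98304 99327
    print(sum(mask))  # → 128 (valid lanes); the other 896 lanes are masked off
```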

x = tl.load(x_ptr + offsets, mask=mask) and y = tl.load(y_ptr + offsets, mask=mask) load one block of each input vector, reading only the positions where mask is True.

output = x + y adds the two blocks element by element, and tl.store(output_ptr + offsets, output, mask=mask) writes the result to the addresses output_ptr + offsets, using mask to ensure only valid elements are written.
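Putting these steps together, the sketch below emulates the whole add_kernel on plain Python lists: a masked load, the addition, and a masked store, looping over program IDs sequentially. It only illustrates the semantics and is not how Triton executes code. One assumption to note: for masked-out lanes this sketch substitutes 0.0 (like passing `other=0.0` to `tl.load`); as I understand the docs, without `other` those lanes hold unspecified values, which is harmless here because they are never stored. All helper names are hypothetical.

```python
# Plain-Python emulation of add_kernel's semantics (sequential, no GPU).
# `masked_load`, `masked_store`, and `emulate_add` are hypothetical helpers.

def masked_load(buffer, offsets, mask, other=0.0):
    """Like tl.load(ptr + offsets, mask=mask, other=other): skip out-of-bounds reads."""
    return [buffer[off] if m else other for off, m in zip(offsets, mask)]


def masked_store(buffer, offsets, values, mask):
    """Like tl.store(ptr + offsets, values, mask=mask): write only where mask is True."""
    for off, val, m in zip(offsets, values, mask):
        if m:
            buffer[off] = val


def emulate_add(x, y, block_size):
    n_elements = len(x)
    output = [None] * n_elements  # preallocated output, like torch.empty_like(x)
    num_programs = -(-n_elements // block_size)  # ceiling division
    for pid in range(num_programs):  # on a GPU these programs run in parallel
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        mask = [off < n_elements for off in offsets]
        xs = masked_load(x, offsets, mask)
        ys = masked_load(y, offsets, mask)
        masked_store(output, offsets, [a + b for a, b in zip(xs, ys)], mask)
    return output


if __name__ == "__main__":
    x = [float(i) for i in range(10)]
    y = [0.5] * 10
    # 3 programs of 4 lanes each; the last program has only 2 valid lanes.
    print(emulate_add(x, y, block_size=4))
```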

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.

    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements

    # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)
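A note on `BLOCK_SIZE: tl.constexpr`: constexpr arguments are compile-time constants, so each distinct value compiles a separate specialization of the kernel. As far as I know, `tl.arange` also requires a power-of-two length, which is why Triton's other tutorials round data-dependent block sizes up with `triton.next_power_of_2`. The plain-Python function below is my own stand-in with the same rounding behavior, not Triton's implementation.

```python
# Plain-Python stand-in for rounding a block size up to a power of two
# (same idea as triton.next_power_of_2; this is not Triton's source code).

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (returns 1 for n <= 1)."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


if __name__ == "__main__":
    for n in (1, 3, 1000, 1024, 1025):
        print(n, "->", next_power_of_2(n))
```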


With the kernel defined, we now write a helper function add() that preallocates the output tensor, computes the launch grid, and launches add_kernel():

def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output


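One confusing part is why grid is defined as a lambda: grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ). When we launch add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024), Triton needs to know how many program instances to start. As the comment in the code notes, grid can be either a tuple or a callable. If it is a callable, Triton calls it with the kernel's meta-parameters (here BLOCK_SIZE), so the grid size can depend on the block size, which is convenient when BLOCK_SIZE is chosen later, for example by auto-tuning. Since BLOCK_SIZE is fixed at 1024 here, a plain tuple such as (triton.cdiv(n_elements, 1024),) would work too.

triton.cdiv() performs ceiling division, and here it computes the number of programs (blocks), not the number of elements per block: triton.cdiv(n_elements, BLOCK_SIZE) is the smallest number of BLOCK_SIZE-element blocks that covers the whole vector. For example, triton.cdiv(100, 3) returns 34 rather than 33.33: 34 blocks of 3 elements are needed to cover 100 elements, and the mask disables the two extra slots in the last block.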
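The grid arithmetic can be checked without a GPU. The sketch below re-implements ceiling division in plain Python (`cdiv` here is my stand-in for `triton.cdiv`) and calls a grid lambda with a dict of meta-parameters, mimicking what the launcher passes in. For the 98432-element vectors used in the correctness check and BLOCK_SIZE=1024, it yields 97 programs.

```python
# Plain-Python check of the launch-grid arithmetic used in add().
# `cdiv` is a stand-in for triton.cdiv; the `meta` dict mimics the
# meta-parameters (such as BLOCK_SIZE) that Triton passes to a callable grid.

def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of size-b blocks needed to cover a elements."""
    return (a + b - 1) // b


n_elements = 98432
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)

if __name__ == "__main__":
    print(cdiv(100, 3))                # → 34 (blocks of 3 elements covering 100 elements)
    print(grid({"BLOCK_SIZE": 1024}))  # → (97,)
    print(grid({"BLOCK_SIZE": 512}))   # → (193,)
```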


Let’s compare the results with PyTorch’s native addition function to see if they are consistent.

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device="cuda")
y = torch.rand(size, device="cuda")
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')


Output:

tensor([1.3713, 1.3076, 0.4940,  ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
The maximum difference between torch and triton is 0.0

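No difference at all! That is expected for this kernel: each output element is a single float32 addition, which rounds the same way in both implementations. For kernels that reorder floating-point operations (reductions, matrix multiplications), small differences are normal, and a tolerance-based check such as torch.testing.assert_close is more appropriate.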

That's it: our first simple Triton kernel is complete. It suits Python developers well, being simple and easy to understand.


Finally, let's benchmark the kernel. Triton's testing utilities (triton.testing.perf_report, triton.testing.Benchmark, and triton.testing.do_bench) produce both a plot and a table of the results. We will use this approach often, typically to compare different optimization techniques.

But today, we are only comparing Triton with native PyTorch.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=[2**i for i in range(12, 28)],
        x_log=True,  # x axis is logarithmic
        line_arg="provider",
        line_names=["Triton", "Torch"],
        line_vals=["triton", "torch"],
        styles=[("blue", "-"), ("green", "-")],  # Line styles
        ylabel="GB/s",  # Label name for the y-axis
        plot_name="vector-add-performance",
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device="cuda", dtype=torch.float32)
    y = torch.rand(size, device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]  # median, 20th and 80th percentile run times

    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)

    # 12 bytes per element: read x and y, write output (4-byte float32 each); ms -> GB/s
    gbps = lambda ms: 12 * size / ms * 1e-6
    return list(map(gbps, [ms, min_ms, max_ms]))

benchmark.run(print_data=True, show_plots=True)
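The `gbps` lambda converts a run time in milliseconds into effective memory bandwidth. Each element involves 12 bytes of memory traffic: two 4-byte float32 reads (x and y) and one 4-byte write (output). Dividing bytes by milliseconds gives bytes per millisecond, and multiplying by 1e-6 turns that into GB/s (1e3 ms per second, 1e9 bytes per GB). A plain-Python version with the arithmetic spelled out (the function name `effective_gbps` is hypothetical):

```python
# Plain-Python version of the benchmark's bandwidth formula:
#   gbps = 12 * size / ms * 1e-6
# `effective_gbps` is a hypothetical helper name.

BYTES_PER_ELEMENT = 3 * 4  # read x, read y, write output; 4 bytes each (float32)


def effective_gbps(size: int, ms: float) -> float:
    total_bytes = BYTES_PER_ELEMENT * size
    seconds = ms * 1e-3
    return total_bytes / seconds / 1e9  # bytes/s -> GB/s (1 GB = 1e9 bytes)


if __name__ == "__main__":
    # 1M elements moved in 0.1 ms: 12 MB / 0.1 ms, i.e. about 120 GB/s
    print(effective_gbps(1_000_000, 0.1))
```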


Output (throughput in GB/s, higher is better):

vector-add-performance:
           size      Triton       Torch
0        4096.0    8.000000    9.600000
1        8192.0   18.177514   16.786886
2       16384.0   31.999999   31.999999
3       32768.0   56.237985   63.503880
4       65536.0   76.800002   76.800002
5      131072.0  109.714284  109.714284
6      262144.0  165.913932  134.111870
7      524288.0  175.542852  175.542852
8     1048576.0  195.047621  198.193551
9     2097152.0  201.442627  203.107443
10    4194304.0  214.637557  215.578943
11    8388608.0  223.418180  222.250103
12   16777216.0  225.366931  223.672355
13   33554432.0  229.682243  228.066990
14   67108864.0  230.828288  224.566540
15  134217728.0  228.096453  224.454365
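Inverting the bandwidth formula turns the table's GB/s figures back into approximate run times, which helps explain the shape of the curve. The sketch below (the helper name `runtime_us` is hypothetical) uses two rows from the table above: at 4096 elements the kernel finishes in roughly 6 µs, so fixed per-launch overhead likely dominates and the reported bandwidth is low; at 134217728 elements it runs for about 7 ms, and both implementations level off around 225 to 230 GB/s, which is what one would expect if they are limited by memory bandwidth.

```python
# Convert a reported GB/s figure back into run time for a vector add of
# `size` float32 elements. `runtime_us` is a hypothetical helper;
# it assumes 12 bytes of traffic per element, as in the benchmark.

def runtime_us(size: int, gbps: float) -> float:
    """Approximate kernel time in microseconds implied by a GB/s figure."""
    total_bytes = 12 * size
    return total_bytes / (gbps * 1e9) * 1e6


if __name__ == "__main__":
    rows = [(4096, 8.0), (134217728, 228.096453)]  # (size, Triton GB/s) from the table
    for size, gbps in rows:
        print(f"size={size:>10}: ~{runtime_us(size, gbps):,.1f} us")
```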

References

