
OpenAI Triton Note (2): Fused Softmax

Last Updated on 2024-09-09 by Clay

Introduction

Softmax is a commonly used activation function, and it is often employed as the last layer in multi-class classification.

\sigma(\vec{z})_{i} = \frac{e^{z_{i}}}{\sum_{j=1}^{K} e^{z_{j}}}

\sigma = softmax
\vec{z} = input vector
e^{z_{i}} = standard exponential function for input vector
K = number of classes in the multi-class classifier
e^{z_{j}} = standard exponential function for output vector

Through the softmax function, we can convert the output of multi-class classification into a probability distribution between (0, 1) where the total sum is 1.
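
For example, using PyTorch's built-in softmax on a small logit vector (a quick illustration, not part of the Triton code below):

import torch

logits = torch.tensor([2.0, 1.0, 0.1])
probs = torch.softmax(logits, dim=-1)

print(probs)        # tensor([0.6590, 0.2424, 0.0986])
print(probs.sum())  # tensor(1.)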

Today we introduce Triton's fused softmax, which merges several operations into a single kernel and is significantly faster than native PyTorch on certain GPUs. (I say certain GPUs because on my laptop's RTX 4060 there is little difference, while on a server with an RTX 3090 the gap is much clearer.)

Based on the previous article (please refer to the Read More section at the end), we can use triton.jit to write Triton kernels, and compare them using Triton’s benchmarking tools — this fused softmax is no exception.


How to Write Softmax in Triton

Before we start, let’s import the necessary packages.

import torch
import triton
import triton.language as tl


Next, we define a naive softmax function and decorate it with torch.jit.script.

@torch.jit.script
def naive_softmax(x):
    # Subtract the row-wise max for numerical stability; this does not change the result
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]

    # Exponentiate and normalize each row so it sums to 1
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    ret = numerator / denominator[:, None]

    return ret


This is a PyTorch-native implementation of softmax, which we try to speed up with PyTorch's just-in-time (JIT) compiler. In real training scripts, however, it is still recommended to use the official torch.softmax() function.

This will be one of our comparison targets.
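
As a quick sanity check, a minimal sketch (assuming a CUDA device, as in the benchmark later) that naive_softmax agrees with the built-in softmax:

x = torch.randn(4, 8, device="cuda")
assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=-1), atol=1e-6)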

Now, let’s try writing the Triton kernel.

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one row of the input matrix
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Offsets within the row; BLOCK_SIZE must be >= n_cols so one block covers the whole row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # Load the row, padding out-of-bounds positions with -inf (exp(-inf) = 0)
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    row_minus_max = row - tl.max(row, axis=0)

    # Numerically stable softmax
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


We use @triton.jit to instruct the Triton compiler to JIT compile this function into a GPU kernel.

Then, row_idx = tl.program_id(0) retrieves the unique program ID of this kernel instance; since we launch one program per row, this ID is the row index, and we use it to compute the starting pointer of the corresponding row in the input matrix.

col_offsets = tl.arange(0, BLOCK_SIZE) creates the offsets 0 to BLOCK_SIZE - 1, representing positions within the row.

input_ptrs = row_start_ptr + col_offsets and row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")) use these offsets to build an array of pointers and load the row's elements, applying a mask to prevent out-of-bounds memory accesses; the masked-off positions are filled with -inf, which becomes 0 after exponentiation and therefore does not affect the result.
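
To make the masking concrete, here is a plain-PyTorch illustration (not Triton code, just a sketch of the same idea) of what happens when BLOCK_SIZE is larger than the row:

import torch

BLOCK_SIZE, n_cols = 8, 5
row_data = torch.arange(n_cols, dtype=torch.float32)  # pretend this is one row in DRAM
col_offsets = torch.arange(BLOCK_SIZE)                # [0, 1, ..., 7]
mask = col_offsets < n_cols                           # last three positions are out of bounds

padded = torch.full((BLOCK_SIZE,), float("-inf"))
padded[mask] = row_data                               # what tl.load(..., other=-inf) yields
print(padded)  # tensor([0., 1., 2., 3., 4., -inf, -inf, -inf])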

The numerator and denominator lines then carry out the usual softmax computation on the max-subtracted row. Finally, tl.store() writes the result back to output_ptrs, again using the mask so that only the valid n_cols elements are written.
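
The benchmark below calls a host-side softmax() function that allocates the output tensor and launches the kernel with one program per row. It is not shown in my notes above, so here is a minimal sketch following the structure of the official Triton tutorial; the next_power_of_2 block size and the num_warps heuristic are taken from that tutorial rather than tuned here:

def softmax(x):
    n_rows, n_cols = x.shape

    # BLOCK_SIZE must be a power of two that covers an entire row
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Heuristic from the official tutorial: use more warps for wider rows
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    y = torch.empty_like(x)

    # Launch one program instance per row
    softmax_kernel[(n_rows,)](
        y, x,
        x.stride(0), y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y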


We can now use Triton's benchmarking tools to compare PyTorch (JIT), native PyTorch softmax, and the Triton kernel:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[128*i for i in range(2, 100)],
        line_arg="provider",
        line_vals=[
            "triton",
            "torch-native",
            "torch-jit",
        ],
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (jit)",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("green", "--"),
        ],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch-native":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
    if provider == "torch-jit":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)

    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(min_ms), gbps(max_ms)


benchmark.run(show_plots=True, print_data=True)
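
A note on the y-axis: the gbps lambda assumes each element of x is read from and written to DRAM exactly once, so the total traffic is 2 * M * N * 4 bytes for float32. A rough back-of-the-envelope check against the N = 12288 row of the table below:

M, N, bytes_per_elem = 4096, 12288, 4
traffic_gb = 2 * M * N * bytes_per_elem * 1e-9   # ≈ 0.403 GB of assumed DRAM traffic
# The table reports about 833 GB/s for Triton at this size,
# which corresponds to a measured time of roughly 0.403 / 833 ≈ 0.48 ms per call.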


Output:

softmax-performance:
          N      Triton  Torch (native)  Torch (jit)
0     256.0  630.153853      682.666643   234.057137
1     384.0  682.666643      722.823517   229.682243
2     512.0  712.347810      712.347810   240.941172
3     640.0  758.518517      731.428561   238.139535
4     768.0  768.000002      744.727267   240.941181
..      ...         ...             ...          ...
93  12160.0  833.903023      436.233193   282.822809
94  12288.0  833.084721      436.421757   283.038742
95  12416.0  834.689099      436.606592   282.382373
96  12544.0  834.528074      436.313034   282.681693
97  12672.0  832.924315      436.025816   282.778247

[98 rows x 4 columns]


As we can see, once N grows beyond a few thousand the Triton kernel is clearly the fastest, while at small N the native torch.softmax is competitive or even slightly ahead.

So far, I’ve gotten a rough understanding of the basics of writing Triton kernels, but I’m still a bit confused about how to integrate them into PyTorch training and inference scripts — how should I properly combine them with PyTorch?

Also, how do I determine the best BLOCK_SIZE to use? Are there any other areas that can be optimized?

It feels like there are still many questions I haven’t fully grasped, so I’ll keep documenting these notes as I go along.




Read More
