Last Updated on 2024-09-09 by Clay
Introduction
Softmax is a commonly used activation function, often employed as the last layer of a multi-class classifier.
$$\sigma(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$

where:
- $\sigma$ = softmax
- $\mathbf{z}$ = input vector
- $e^{z_i}$ = standard exponential function for the input vector
- $K$ = number of classes in the multi-class classifier
- $e^{z_j}$ = standard exponential function for the output vector
Through the softmax function, we can convert the outputs of a multi-class classifier into a probability distribution: every value lies in (0, 1) and the values sum to 1.
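As a quick sanity check, here is a tiny example; the logits are arbitrary values chosen only for illustration:

import torch

logits = torch.tensor([2.0, 1.0, 0.1])
probs = torch.softmax(logits, dim=0)

print(probs)        # ≈ tensor([0.6590, 0.2424, 0.0986])
print(probs.sum())  # 1.0 (up to floating-point rounding)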
Today we introduce a fused softmax written in Triton: it merges several operations into a single kernel and can be significantly faster than native PyTorch on certain GPUs. (I say certain GPUs because on my laptop's 4060 there is little difference, but on a server with a 3090 the speedup is much more obvious.)
Building on the previous article (please refer to the Read More section at the end), we can use triton.jit to write Triton kernels and compare them with Triton's benchmarking tools; this fused softmax is no exception.
How to Write Softmax in Triton
Before we start, let’s import the necessary packages.
import torch
import triton
import triton.language as tl
Next, we define a function using torch.jit.script.
@torch.jit.script
def naive_softmax(x):
    # Subtract the row-wise max for numerical stability before exponentiating
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    ret = numerator / denominator[:, None]
    return ret
This is the PyTorch-native way of computing the softmax function, with an attempt to optimize it through PyTorch's Just-In-Time compilation (JIT). In real training scripts, however, it is still recommended to use the official torch.softmax() function.
This will be one of our comparison targets.
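Before benchmarking, it is worth a quick sanity check that the naive version matches PyTorch's built-in softmax; the tensor shape below is arbitrary:

x = torch.randn(8, 16, device="cuda")
print(torch.allclose(naive_softmax(x), torch.softmax(x, dim=1)))  # expected: True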
Now, let’s try writing the Triton kernel.
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles exactly one row of the input matrix
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Column offsets covering the whole row; out-of-bounds slots are masked
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))

    # Numerically stable softmax: subtract the row max before exponentiating
    row_minus_max = row - tl.max(row, axis=0)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
We use @triton.jit to instruct the Triton compiler to JIT-compile this function into a GPU kernel. Then row_idx = tl.program_id(0) retrieves the unique program ID of this kernel instance, and row_start_ptr uses it to compute the starting pointer of the corresponding row in the input matrix.
col_offsets = tl.arange(0, BLOCK_SIZE) creates a sequence of offsets from 0 up to (but not including) BLOCK_SIZE, representing the column positions within the row.
input_ptrs = row_start_ptr + col_offsets and row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")) use these offsets to build a pointer array and load the row's elements, applying a mask so that out-of-bounds positions are never read from memory and are instead filled with -inf.
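To see why padding with -inf is harmless, here is a quick PyTorch analogy (the values are made up for illustration): exp(-inf) is 0, so the padded slots contribute nothing to the sum.

import torch

# A row of 3 real values, padded to a hypothetical BLOCK_SIZE of 4 with -inf
row = torch.tensor([1.0, 2.0, 3.0, float("-inf")])

print(torch.exp(row - row.max()))  # the padded slot becomes exp(-inf) = 0
print(torch.softmax(row, dim=0))   # same result as softmax over just the first 3 values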
The numerator and denominator lines then carry out the usual softmax computation on the loaded row. Finally, tl.store() writes the result back to the output locations given by output_ptrs.
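The benchmark below also calls a host-side softmax() function that allocates the output tensor and launches one kernel instance per row. Here is a minimal sketch of such a wrapper, modeled on the official Triton tutorial; the power-of-two BLOCK_SIZE and the 1D launch grid are assumptions of that layout rather than something dictated by the kernel itself.

def softmax(x):
    n_rows, n_cols = x.shape

    # BLOCK_SIZE must be a power of two that is >= n_cols,
    # so one program instance can load an entire row at once
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    y = torch.empty_like(x)

    # 1D launch grid: one program instance per row
    softmax_kernel[(n_rows,)](
        y, x,
        x.stride(0), y.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y

A quick correctness check against PyTorch (the shape is arbitrary):

x = torch.randn(1823, 781, device="cuda")
print(torch.allclose(softmax(x), torch.softmax(x, dim=1)))  # expected: True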
We can now use Triton's benchmarking tools to compare the performance of the Torch JIT version, native PyTorch softmax, and the Triton kernel:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],                            # the argument swept along the x-axis
        x_vals=[128 * i for i in range(2, 100)],
        line_arg="provider",                      # each provider becomes one line in the plot
        line_vals=[
            "triton",
            "torch-native",
            "torch-jit",
        ],
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (jit)",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("green", "--"),
        ],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},                         # fixed number of rows
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-native":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
    if provider == "torch-jit":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)

    # Effective bandwidth: each element is read once and written once, hence the factor of 2
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(min_ms), gbps(max_ms)


benchmark.run(show_plots=True, print_data=True)
Output:
softmax-performance:
          N      Triton  Torch (native)  Torch (jit)
0     256.0  630.153853      682.666643   234.057137
1     384.0  682.666643      722.823517   229.682243
2     512.0  712.347810      712.347810   240.941172
3     640.0  758.518517      731.428561   238.139535
4     768.0  768.000002      744.727267   240.941181
..      ...         ...             ...          ...
93  12160.0  833.903023      436.233193   282.822809
94  12288.0  833.084721      436.421757   283.038742
95  12416.0  834.689099      436.606592   282.382373
96  12544.0  834.528074      436.313034   282.681693
97  12672.0  832.924315      436.025816   282.778247

[98 rows x 4 columns]
As we can see, apart from the smallest sizes the Triton kernel is the fastest of the three; its advantage over native PyTorch grows dramatically as N increases, while the JIT-scripted naive version lags far behind throughout.
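To put the GB/s figures in perspective, here is the back-of-the-envelope arithmetic behind the gbps lambda in the benchmark, using the largest shape in the sweep (one read plus one write per float32 element):

M, N = 4096, 12672           # the largest shape in the sweep above
bytes_moved = 2 * M * N * 4  # one read + one write per 4-byte float32 element
gb = bytes_moved * 1e-9      # ≈ 0.415 GB of memory traffic per softmax call

# At roughly 833 GB/s (the Triton column), one call takes about 0.5 ms
print(gb / 833)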
So far, I’ve gotten a rough understanding of the basics of writing Triton kernels, but I’m still a bit confused about how to integrate them into PyTorch training and inference scripts — how should I properly combine them with PyTorch?
Also, how do I determine the best BLOCK_SIZE to use? Are there any other areas that can be optimized?
It feels like there are still many questions I haven’t fully grasped, so I’ll keep documenting these notes as I go along.