Last Updated on 2024-01-29 by Clay
介紹
Triton 是一套開源的 GPU 程式語言編譯器,由 OpenAI 於 2021 年發佈,近年來有越來越多的開發使用 Triton 來編寫與優化在 GPU 上的併行程式。相較傳統 CUDA/OpenCL 等函式庫,Triton 提供了一種 Python-like 語法,顯得更清晰與容易上手。
Triton 有以下特色:
- 簡化記憶體訪問:Triton 抽象出了一層高效的記憶體訪問界面,允許開發者通過高級的操作來加載/儲存資料,而需要直接進行複雜的記憶體管理
- 自動性能優化:通過自動調整內核的執行與資料吞吐量,進而幫助減少手動優化
- Python 整合:Triton 是一個 Python 的函式庫,可以直接與 Python 程式碼整合在一起,使得優化 GPU 加速的過程變得更容易
而以下,我會紀錄如何進行 Triton 官方文件中的第一個教學、並把教學中我不懂但又去找來的一些資料整理在一起,希望能幫助所有想學習 Triton 的人、以及未來回來複習的我。
安裝
會來學習 Triton 的我想都是進階開發者了,所以就不多贅述,請大家盡量在 Python 虛擬環境中安裝套件,避免影響原生系統。
pip install triton
如果想要安裝最新的預覽版本、或是想自行從來源安裝,請參考文末的參考資料。
範例程式: 向量相加
以下會學習:
- 使用 triton.jit 裝飾器去定義一個基礎的 Triton kernel
- 驗證與使用 Triton Benchmark,評估 Triton kernel 的性能
以下,我們匯入必要的套件後,開始使用 @triton.jit
(jit 為 just-in-time,即時編譯的意思)去定義 Triton kernel。
其中,x_ptr
、y_ptr
、output_ptr
都是指標(pointer);在 Triton kernel 中,傳遞的都是指標而非物件。指標可視為對變數的引用,其指向變數在記憶體中的記憶體位址,允許直接訪問與操作。另外,output
是我們需要預先定義好一個空陣列,用於儲存 x
和 y
的相加結果。
pid 是用於獲取當前 program 的 ID,program 是 Triton 的最基本執行單位。在這裡使用了 axis=0
代表使用了一維的啟動網格(launch grid)。
block_start = pid * BLOCK_SIZE
是計算當前的程式所處理的資料 block 起始點。
mask = offsets < n_elements
是建立一個掩碼,用來防止訪問超出向量尾端的元素。這是因為我們所處理的資料並不一定總是被 block 的尺寸剛好整除。
x = tl.load(x_ptr + offsets, mask=mask)
和 y = tl.load(y_ptr + offsets, mask=mask)
都是使用掩碼 mask
去加載向量。
tl.store(output_ptr + offsets, output, mask=mask)
是把 x
和 y
相加的 output
儲存回由 output_ptr
指定的記憶體位址,使用 offset
和 mask
來確保只更新有效的元素。
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
):
# There are multiple 'programs' processing different data. We identify which program
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
# This program will process inputs that are offset from the initial data.
# For instance, if you had a vector of length 256 and block_size of 64, the programs
# Would each access the elements [0:64, 64:128, 128:192, 192:256].
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses.
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size.
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM
tl.store(output_ptr + offsets, output, mask=mask)
定義完成相加兩向量的 Triton kernel 後,現在我們馬上來建立調用 add_kernel()
這個 Triton 核心的函式 add()
:
def add(x: torch.Tensor, y: torch.Tensor):
# We need to preallocate the output
output = torch.empty_like(x)
assert x.is_cuda and y.is_cuda and output.is_cuda
n_elements = output.numel()
# The SPMD launch grid denotes the number of kernel instances that run in parallel
# It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
# In this case, we use a 1D grid where the size is the number of blocks
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
這段程式碼最讓人糊塗的就是為什麼要建立 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
這個匿名函式。這是因為之後的 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
實際上調用了之前定義的 add_kernel
,而它又需要指定一個函式來配置 kernel 本身需要的啟動參數 —— 這就是為什麼我們需要一個 grid
函式,並且其網格大小基於每個 block 應該處理的元素量,所以才會使用 triton.cdiv()
去計算一個 block 要處理多少元素(計算結果一定是整數,比方說 triton.cdiv(100, 3)
,有 100 個元素並且有 3 個 blocks,其計算結果一定是 34 而非 33.333333)。
來看看結果跟 PyTorch 自身的相加是否一致。
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device="cuda")
y = torch.rand(size, device="cuda")
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
f'{torch.max(torch.abs(output_torch - output_triton))}')
Output:
tensor([1.3713, 1.3076, 0.4940, ..., 0.9592, 0.3409, 1.2567], device='cuda:0') tensor([1.3713, 1.3076, 0.4940, ..., 0.9592, 0.3409, 1.2567], device='cuda:0') The maximum difference between torch and triton is 0.0
完全沒有誤差!
就這樣,我們第一個簡單的 Triton kernel 寫完了,真的是十分貼齊 Python 開發者的偏好,簡單好懂。
最後,也一併來進行 Triton 的 benchmark 測試,Triton 整合了視覺化的線圖和表格型的文字顯示;這個方法以後會非常常用,也會跟其他的優化方法比較。
但今天,只有 Triton 和原生 PyTorch。
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["size"],
x_vals=[2**i for i in range(12, 28)],
x_log=True, # x axis is logarithmic
line_arg="provider",
line_names=["Triton", "Torch"],
line_vals=["triton", "torch"],
styles=[("blue", "-"), ("green", "-")], # Line styles
ylabel="GB/s", # Label name for the y-axis
plot_name="vector-add-performance",
args={},
)
)
def benchmark(size, provider):
x = torch.rand(size, device="cuda", dtype=torch.float32)
y = torch.rand(size, device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
gbps = lambda ms: 12 * size / ms * 1e-6
return list(map(gbps, [ms, min_ms, max_ms]))
benchmark.run(print_data=True, show_plots=True)
Output:
vector-add-performance: size Triton Torch 0 4096.0 8.000000 9.600000 1 8192.0 18.177514 16.786886 2 16384.0 31.999999 31.999999 3 32768.0 56.237985 63.503880 4 65536.0 76.800002 76.800002 5 131072.0 109.714284 109.714284 6 262144.0 165.913932 134.111870 7 524288.0 175.542852 175.542852 8 1048576.0 195.047621 198.193551 9 2097152.0 201.442627 203.107443 10 4194304.0 214.637557 215.578943 11 8388608.0 223.418180 222.250103 12 16777216.0 225.366931 223.672355 13 33554432.0 229.682243 228.066990 14 67108864.0 230.828288 224.566540 15 134217728.0 228.096453 224.454365