Skip to content

OpenAI Triton Note (1): 向量相加

Last Updated on 2024-01-29 by Clay

介紹

Triton 是一套開源的 GPU 程式語言編譯器,由 OpenAI 於 2021 年發佈,近年來有越來越多的開發使用 Triton 來編寫與優化在 GPU 上的併行程式。相較傳統 CUDA/OpenCL 等函式庫,Triton 提供了一種 Python-like 語法,顯得更清晰與容易上手。

Triton 有以下特色:

  • 簡化記憶體訪問:Triton 抽象出了一層高效的記憶體訪問界面,允許開發者通過高級的操作來加載/儲存資料,而需要直接進行複雜的記憶體管理
  • 自動性能優化:通過自動調整內核的執行與資料吞吐量,進而幫助減少手動優化
  • Python 整合:Triton 是一個 Python 的函式庫,可以直接與 Python 程式碼整合在一起,使得優化 GPU 加速的過程變得更容易

而以下,我會紀錄如何進行 Triton 官方文件中的第一個教學、並把教學中我不懂但又去找來的一些資料整理在一起,希望能幫助所有想學習 Triton 的人、以及未來回來複習的我。


安裝

會來學習 Triton 的我想都是進階開發者了,所以就不多贅述,請大家盡量在 Python 虛擬環境中安裝套件,避免影響原生系統。

pip install triton


如果想要安裝最新的預覽版本、或是想自行從來源安裝,請參考文末的參考資料。


範例程式: 向量相加

以下會學習:

  • 使用 triton.jit 裝飾器去定義一個基礎的 Triton kernel
  • 驗證與使用 Triton Benchmark,評估 Triton kernel 的性能


以下,我們匯入必要的套件後,開始使用 @triton.jit(jit 為 just-in-time,即時編譯的意思)去定義 Triton kernel。

其中,x_ptry_ptroutput_ptr 都是指標(pointer);在 Triton kernel 中,傳遞的都是指標而非物件。指標可視為對變數的引用,其指向變數在記憶體中的記憶體位址,允許直接訪問與操作。另外,output 是我們需要預先定義好一個空陣列,用於儲存 xy 的相加結果。

pid 是用於獲取當前 program 的 ID,program 是 Triton 的最基本執行單位。在這裡使用了 axis=0 代表使用了一維的啟動網格launch grid)。

block_start = pid * BLOCK_SIZE 是計算當前的程式所處理的資料 block 起始點。

mask = offsets < n_elements 是建立一個掩碼,用來防止訪問超出向量尾端的元素。這是因為我們所處理的資料並不一定總是被 block 的尺寸剛好整除。

x = tl.load(x_ptr + offsets, mask=mask)y = tl.load(y_ptr + offsets, mask=mask) 都是使用掩碼 mask 去加載向量。

tl.store(output_ptr + offsets, output, mask=mask) 是把 xy 相加的 output 儲存回由 output_ptr 指定的記憶體位址,使用 offsetmask 來確保只更新有效的元素。

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
):
    # There are multiple 'programs' processing different data. We identify which program
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.

    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # Would each access the elements [0:64, 64:128, 128:192, 192:256].
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements

    # Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)


定義完成相加兩向量的 Triton kernel 後,現在我們馬上來建立調用 add_kernel() 這個 Triton 核心的函式 add()

def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output


這段程式碼最讓人糊塗的就是為什麼要建立 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) 這個匿名函式。這是因為之後的 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) 實際上調用了之前定義的 add_kernel,而它又需要指定一個函式來配置 kernel 本身需要的啟動參數 —— 這就是為什麼我們需要一個 grid 函式,並且其網格大小基於每個 block 應該處理的元素量,所以才會使用 triton.cdiv() 去計算一個 block 要處理多少元素(計算結果一定是整數,比方說 triton.cdiv(100, 3),有 100 個元素並且有 3 個 blocks,其計算結果一定是 34 而非 33.333333)。


來看看結果跟 PyTorch 自身的相加是否一致。

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device="cuda")
y = torch.rand(size, device="cuda")
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')


Output:

tensor([1.3713, 1.3076, 0.4940,  ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
The maximum difference between torch and triton is 0.0

完全沒有誤差!

就這樣,我們第一個簡單的 Triton kernel 寫完了,真的是十分貼齊 Python 開發者的偏好,簡單好懂。


最後,也一併來進行 Triton 的 benchmark 測試,Triton 整合了視覺化的線圖和表格型的文字顯示;這個方法以後會非常常用,也會跟其他的優化方法比較。

但今天,只有 Triton 和原生 PyTorch。

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],
        x_vals=[2**i for i in range(12, 28)],
        x_log=True,  # x axis is logarithmic
        line_arg="provider",
        line_names=["Triton", "Torch"],
        line_vals=["triton", "torch"],
        styles=[("blue", "-"), ("green", "-")],  # Line styles
        ylabel="GB/s",  # Label name for the y-axis
        plot_name="vector-add-performance",
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device="cuda", dtype=torch.float32)
    y = torch.rand(size, device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)

    gbps = lambda ms: 12 * size / ms * 1e-6
    return list(map(gbps, [ms, min_ms, max_ms]))

benchmark.run(print_data=True, show_plots=True)


Output:

vector-add-performance:
           size      Triton       Torch
0        4096.0    8.000000    9.600000
1        8192.0   18.177514   16.786886
2       16384.0   31.999999   31.999999
3       32768.0   56.237985   63.503880
4       65536.0   76.800002   76.800002
5      131072.0  109.714284  109.714284
6      262144.0  165.913932  134.111870
7      524288.0  175.542852  175.542852
8     1048576.0  195.047621  198.193551
9     2097152.0  201.442627  203.107443
10    4194304.0  214.637557  215.578943
11    8388608.0  223.418180  222.250103
12   16777216.0  225.366931  223.672355
13   33554432.0  229.682243  228.066990
14   67108864.0  230.828288  224.566540
15  134217728.0  228.096453  224.454365

References


Read More

Leave a Reply