Skip to content

OpenAI Triton Note (2): Fused Softmax

Last Updated on 2024-01-29 by Clay

介紹

Softmax 是一個常見的激活函數activation function),也經常被用作多分類的最後一層。

\sigma=softmax
\vec{z}=input vector
e^{z_{i}}=standard exponential function for input vector
K=number of classes in the multi-class classifier
e^{z_{j}}=standard exponential function for output vector
e^{z_{j}}=standard exponential function for output vector

通過 softmax 函數,我們可以將多分類的輸出轉換成一組介於 (0, 1) 之間並且加總為 1 的機率分佈。

而在 Triton 中,今天則介紹了將多個操作融合在一起的 softmax 函數,並能在一定性能的 GPU 上明顯地發現比原生的 PyTorch 來得速度更快。(之所以說要一定性能,是因為若使用我筆電的 4060 laptop 版本看不出差異,但是在 3090 的 server 上能很清晰地分出高下

根據前一篇內容(請參考文末 Read More),我們可以使用 triton.jit 去撰寫 Triton kernel,並使用 Triton 的 benchmark 進行比較 —— 本篇 fused softmax 也不例外要使用這些工具。


Triton 如何撰寫 softmax

在開始前,我們先匯入必要的套件。

import torch
import triton
import triton.language as tl


接著我們定義一個 torch.jit.script 的函式。

@torch.jit.script
def naive_softmax(x):
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    ret = numerator / denominator[:, None]
    
    return ret


這是 PyTorch 原生計算 softmax 函式的方法,並且嘗試使用了 PyTorch 的即時編譯(Just-In-Time Compilation, JIT)優化;當然,在真實訓練腳本中,仍然推薦使用官方實現的 torch.softmax() 函式。

這是我們的其中一個比較對象。

那接下來,我們嘗試撰寫 Triton 的 kernel。

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    row_minus_max = row - tl.max(row, axis=0)

    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


我們使用 @triton.jit 告訴 Triton 編譯器這個函式需要被 JIT 編譯成 GPU 上的核心。

接著我們使用 row_idx = tl.program_id(0) 去取得核心的唯一 ID,並計算輸入矩陣中對應當前 row 的起始指標位置。

col_offsets = tl.arange(0, BLOCK_SIZE) 建立一個序列,從 0 到 BLOCK_SIZE,表示 row 的偏移量。

input_ptrs = row_start_ptr + col_offsetsrow = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")) 則是用剛才建立的 row 偏移量建立一個指標陣列,指向 row 中的每個元素,同時將其載入並使用掩碼避免超出應存取的記憶體位址。

下面 numeratordenominator 就可視為常規的 softmax 函數計算了。最後使用 tl.store() 將結果儲存回預留的 output_ptrs 中。


我們可以比較一下在 Triton 的 benchmark 上,PyTorch(JIT)、PyTorch Softmax 以及 Triton kernel 哪個性能最好:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[128*i for i in range(2, 100)],
        line_arg="provider",
        line_vals=[
            "triton",
            "torch-native",
            "torch-jit",
        ],
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (jit)",
        ],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("green", "--"),
        ],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch-native":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
    if provider == "torch-jit":
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)

    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(min_ms), gbps(max_ms)


benchmark.run(show_plots=True, print_data=True)


Output:

softmax-performance:
          N      Triton  Torch (native)  Torch (jit)
0     256.0  630.153853      682.666643   234.057137
1     384.0  682.666643      722.823517   229.682243
2     512.0  712.347810      712.347810   240.941172
3     640.0  758.518517      731.428561   238.139535
4     768.0  768.000002      744.727267   240.941181
..      ...         ...             ...          ...
93  12160.0  833.903023      436.233193   282.822809
94  12288.0  833.084721      436.421757   283.038742
95  12416.0  834.689099      436.606592   282.382373
96  12544.0  834.528074      436.313034   282.681693
97  12672.0  832.924315      436.025816   282.778247

[98 rows x 4 columns]


可以發現,Triton kernel 果然是最快的!

目前為止,我差不多摸清楚了 Triton Kernel 的基本寫法,但是對如何應用在 PyTorch 的訓練、推理腳本中還有些迷茫 —— 到底該如何跟我的 PyTorch 整合呢?

另外,到底該怎麼寫、該怎麼分配 BLOCK_SIZE 才是最好的?還有哪些地方可以調整呢?

感覺還有許多課題還沒有真正領悟,我想我會繼續把這份筆記紀錄下去。


References


Read More

Leave a Reply