Last Updated on 2024-01-29 by Clay
介紹
Softmax 是一個常見的激活函數(activation function),也經常被用作多分類的最後一層。
= | softmax | |
= | input vector | |
= | standard exponential function for input vector | |
= | number of classes in the multi-class classifier | |
= | standard exponential function for output vector | |
= | standard exponential function for output vector |
通過 softmax 函數,我們可以將多分類的輸出轉換成一組介於 (0, 1) 之間並且加總為 1 的機率分佈。
而在 Triton 中,今天則介紹了將多個操作融合在一起的 softmax 函數,並能在一定性能的 GPU 上明顯地發現比原生的 PyTorch 來得速度更快。(之所以說要一定性能,是因為若使用我筆電的 4060 laptop 版本看不出差異,但是在 3090 的 server 上能很清晰地分出高下)
根據前一篇內容(請參考文末 Read More),我們可以使用 triton.jit 去撰寫 Triton kernel,並使用 Triton 的 benchmark 進行比較 —— 本篇 fused softmax 也不例外要使用這些工具。
Triton 如何撰寫 softmax
在開始前,我們先匯入必要的套件。
import torch
import triton
import triton.language as tl
接著我們定義一個 torch.jit.script 的函式。
@torch.jit.script
def naive_softmax(x):
x_max = x.max(dim=1)[0]
z = x - x_max[:, None]
numerator = torch.exp(z)
denominator = numerator.sum(dim=1)
ret = numerator / denominator[:, None]
return ret
這是 PyTorch 原生計算 softmax 函式的方法,並且嘗試使用了 PyTorch 的即時編譯(Just-In-Time Compilation, JIT)優化;當然,在真實訓練腳本中,仍然推薦使用官方實現的 torch.softmax()
函式。
這是我們的其中一個比較對象。
那接下來,我們嘗試撰寫 Triton 的 kernel。
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
我們使用 @triton.jit
告訴 Triton 編譯器這個函式需要被 JIT 編譯成 GPU 上的核心。
接著我們使用 row_idx = tl.program_id(0)
去取得核心的唯一 ID,並計算輸入矩陣中對應當前 row 的起始指標位置。
col_offsets = tl.arange(0, BLOCK_SIZE)
建立一個序列,從 0 到 BLOCK_SIZE
,表示 row 的偏移量。
input_ptrs = row_start_ptr + col_offsets
和 row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf"))
則是用剛才建立的 row 偏移量建立一個指標陣列,指向 row 中的每個元素,同時將其載入並使用掩碼避免超出應存取的記憶體位址。
下面 numerator
和 denominator
就可視為常規的 softmax 函數計算了。最後使用 tl.store()
將結果儲存回預留的 output_ptrs
中。
我們可以比較一下在 Triton 的 benchmark 上,PyTorch(JIT)、PyTorch Softmax 以及 Triton kernel 哪個性能最好:
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["N"],
x_vals=[128*i for i in range(2, 100)],
line_arg="provider",
line_vals=[
"triton",
"torch-native",
"torch-jit",
],
line_names=[
"Triton",
"Torch (native)",
"Torch (jit)",
],
styles=[
("blue", "-"),
("green", "-"),
("green", "--"),
],
ylabel="GB/s",
plot_name="softmax-performance",
args={"M": 4096},
)
)
def benchmark(M, N, provider):
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-native":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
if provider == "torch-jit":
ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(min_ms), gbps(max_ms)
benchmark.run(show_plots=True, print_data=True)
Output:
softmax-performance: N Triton Torch (native) Torch (jit) 0 256.0 630.153853 682.666643 234.057137 1 384.0 682.666643 722.823517 229.682243 2 512.0 712.347810 712.347810 240.941172 3 640.0 758.518517 731.428561 238.139535 4 768.0 768.000002 744.727267 240.941181 .. ... ... ... ... 93 12160.0 833.903023 436.233193 282.822809 94 12288.0 833.084721 436.421757 283.038742 95 12416.0 834.689099 436.606592 282.382373 96 12544.0 834.528074 436.313034 282.681693 97 12672.0 832.924315 436.025816 282.778247 [98 rows x 4 columns]
可以發現,Triton kernel 果然是最快的!
目前為止,我差不多摸清楚了 Triton Kernel 的基本寫法,但是對如何應用在 PyTorch 的訓練、推理腳本中還有些迷茫 —— 到底該如何跟我的 PyTorch 整合呢?
另外,到底該怎麼寫、該怎麼分配 BLOCK_SIZE
才是最好的?還有哪些地方可以調整呢?
感覺還有許多課題還沒有真正領悟,我想我會繼續把這份筆記紀錄下去。