Last Updated on 2025-02-24 by Clay
之前有在閱讀模型架構原始碼時,有嘗試寫過 LayerNorm 的實現([Machine Learning] LayerNorm 層歸一化筆記),但當時的實作也僅只於按照公式復現;最近在複習架構設計時,對於 LayerNorm 有了更深入的思考,故筆記於此。
LayerNorm (Layer Normalization, 層歸一化)的主要目的是控制模型的 hidden states 範圍,穩定神經網路的學習過程。
簡單來說,LayerNorm 對於每一層的輸入向量進行標準化,讓其值分佈穩定;這樣做甚至會提昇訓練的收斂速度和模型的穩定性。
在 Transformer 架構中,每一層神經網路輸出的 hidden states 會因為線性變換與非線性激活而不斷變大或縮小。如果輸出的數值範圍過大或過小,會引起梯度爆炸或梯度消失。
所以更明白地說,LayerNorm 的作用就是將 hidden states 調整為『均值為 0、標準差為 1』的範圍,防止梯度紊亂。而之所以說能加速訓練,是因為 hidden states 維持在穩定範圍內,更容易讓優化器(Adam、SGD)找到最小值。
對於一組 hidden states 向量
: 該層的均值 : 該層的方差 : 極小的常數,用來防止分母為零
而為了保留模型的表現力,LayerNorm 再加入了可學習參數
: 縮放係數(learnable scale parameter) : 平移係數(learnable shift parameter)
補充說明
- LayerNorm 經常會與另外一種正規化方法 BatchNorm 比較。不過 BatchNorm 比較常用於 CV 模型,並且是依賴 batch size 去縮放,所以在小 batch size 時效果會浮動。
- LayerNorm 更常用於 NLP 和序列處理模型,不依賴 batch 中的不同資料特徵從而保證 inference 時效果穩定。
- LayerNorm 對於 GPU 的記憶體訪問沒有那麼高效,因為它按照每個樣本進行標準化。
- RMSNorm 是 LayerNorm 的變體,使用 均方根 (RMS) 進行標準化,並保留 可學習的縮放參數
和平移參數 ,在 LLM 中常被採用,計算更高效且效果接近 LayerNorm。