Skip to content

[Machine Learning] KL 散度(Kullback-Leibler Divergence)筆記

Last Updated on 2023-07-19 by Clay

什麼是 KL 散度(Divergence)?

我們經常會在機器學習中聽到 『KL 散度』(Kullback Libler Divergence)這個詞,而 KL 散度其實就是評估兩個機率分佈 P 和 Q 之間『差異性』的一個評估值。

KL 散度有許多不同的別名,在每個領域可能都有屬於自己的名字,比如說相對熵、資訊增益、資訊散度...... 等等稱呼。但本質上是一樣的東西。

對於離散機率分佈,KL 散度的公式如下:

對於連續機率分布,KL 散度的公式如下:

其中,P(i) 跟 Q(i) 都是機率分佈在 i 點上的機率,log 則是自然對數。

KL 散度有以下幾個特性:

  • 非負性:即使在 P 等於 Q 的情況,D(P||Q) = 0
  • 非對稱性:D(P||Q) 不等於 D(Q||P),一個是 P 為基準、一個是以 Q 為基準
  • 資訊增益(information gain):如果一段資訊(先不管是什麼資訊)我們使用 Q 來編碼,那麼換成 P 來編碼時,我們可以平均減少多少 bits 的資訊,這就是 information gain


講了這麼多,可能又有點抽象,我來舉一個實際應用的例子好了(我碩班時寫作業還寫過這東西... 但那應該稱不上實際應用):

不知道大家是否聽過 PPO 呢?有陣子有許多人在討論 ChatGPT 是否是使用 PPO 的方式訓練而成的,但是結論了我忘記了。

PPO 的全名是 Proximal Policy Optimization,是一種在深度強化學習中的策略優化方法。這種方法的主要目標是解決策略梯度方法中一個常見的問題:在更新策略的時候,可能會使得新的策略與原來的策略差距過大,進而導致學習的不穩定。

在 PPO 中,這種問題是通過添加一個額外的損失項來解決的,這個損失項衡量的就是新策略與舊策略的 KL 散度。通過最小化這個 KL 散度,我們可以確保在每一步更新中,新策略不會與舊策略差距過大,從而提高學習的穩定性。

可以想像成,原本的模型 A 已經很會做 X 任務了,但現在我們要讓 A 模型繼續學習 Y 任務 —— 然而,如果我們什麼額外的資訊都不約束,最終我們訓練在 Y 任務上的 A' 模型就會變得不太像 A 模型了。

所以我們在訓練 A' 模型時,我們時時刻刻將 A' 模型的輸出與 A 模型計算 KL 散度,試圖在最小化 KL 散度的同時又學習下一個任務 Y。如此一來我們就不會讓 A' 模型距離 A 太遙遠。

不知道這樣一來,有沒有比較感受到 KL 散度的光輝了呢?


程式實作

如果我們想要實際寫寫看計算 KL 散度的程式,可以參考以下 Python 的寫法:

import numpy as np
from scipy.stats import norm
from scipy.special import kl_div

# Assume we have two normal distributions: N(0, 1) and N(1, 2)
mu1, sigma1 = 0, 1
mu2, sigma2 = 1, 2

# Generate an array of values between -10 and 10
x = np.linspace(-10, 10, 1000)

# Calculate the probability density of the two distributions at these points
p = norm.pdf(x, mu1, sigma1)
q = norm.pdf(x, mu2, sigma2)

# Use the kl_div function from scipy to calculate KL divergence
kl_divergence = kl_div(p, q)

# Since KL divergence is a sum of the results for each point, we need to add them up
kl_divergence = np.sum(kl_divergence)

print(kl_divergence)


Output:

22.135034999282027

上方的 Python 程式碼計算了兩個常態分佈之間的 KL 散度,可以看到兩個分佈之間的差異相當地大。

在這裡使用到了 scipy 函式庫中的 kl_div 函式,他可以直接計算兩個離散機率分佈之間的 KL 散度,但是這只是一個近似值,因為真正的 KL 散度需要對整個機率分佈做積分。


References


Read More

Leave a Reply取消回覆

Exit mobile version