Skip to content

KL Divergence 和 Cross Entropy 的差異與比較

Last Updated on 2024-12-03 by Clay

前言

最近在實作論文 Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting 時,對於其所採用 Cross Entropy Loss 來讓 draft model 和 target model 的機率分佈輸出越像越好這件事,產生了一個疑惑:為什麼不使用 KL Divergence 呢?

通俗來說,我們知道 Cross Entropy 和 KL Divergence 都是用於衡量兩個分佈;Cross Entropy 用來衡量兩個分佈 p 和 q 之間的相似程度、KL Divergence 用來衡量兩個分佈 p 和 q 之間的距離。

其中,p 我們定義為真實分佈,q為預測分佈。一般以深度學習的角度來說,我們希望預測分佈 q 能夠跟真實分佈 p 越像越好。


數學定義

Cross-Entropy(交叉熵)

H(p,q)=-\sum_{x}p(x)log(q(x))

KL Divergence(KL 散度)

H(p)=-\sum_{x}p(x)log(p(x)),

D_{KL}(p\parallel q)=\sum_{x}p(x)log(\frac{p(x)}{q(x)})=H(p,q)-H(p)

也就是說,我們可以想像 Cross-Entropy 其實比 KL Divergence 多包含一個真實分佈 p 的熵值 H(p)。不過,這也導致在我們的真實分佈 p 為 one-hot encoding 時(只有一個類別的機率為 1,其餘類別為 0 ),Cross-Entropy 和 KL Divergence 是完全等價的。

為什麼是完全等價的呢?這是因為本來兩者之間差異的 $H(x)$ ,是目標分佈(真實分佈)的熵,可以寫為:

H(x)=p(x)log(p(x)),只有一個類別是 $p(x)=1$ 其餘皆為 $0$,但是在那個類別中的 log(p(x))=log(1)=0,所以熵 $H(P)=0$ 是確定的。

這樣一來:

D_{KL}(p\parallel q)=\sum_{x}p(x)log(\frac{p(x)}{q(x)})=H(p,q)-H(p)=H(p,q)-0=H(p,q)

此時 KL Divergence 跟 Cross-Entropy 不能說很像,只能說是一模一樣!


Cross-Entropy 和 KL Divergence 的差異

接下來終於要進入正題了:Cross-Entropy 和 KL Divergence 的差異在哪呢?

其實我們剛剛已經提過了,在 hard label 時是等價的,那麼在 soft label 時就不同了。

對於 Cross-Entropy 來說,對於某一個特定類別 x 都會是使用 -p(x)log(q(x)) 來計算損失,所以對 p(x) 極小的類別來說影響相對較小 —— 換句話說就是不太在意這種極小值的分佈形狀。

而 KL Divergence 就不同了,每個類別的貢獻都會按照 p(x)log(\frac{p(x)}{q(x)}) 去計算,也就是同時需要滿足 log(\frac{p(x)}{q(x)}) 很小的情況(p(x) 和 q(x) 接近),其距離才會真的很小。

舉個實際的案例來看:

import torch

# Define the distributions P and Q
P = torch.tensor([0.98, 0.01, 0.01])  # Target distribution
Q = torch.tensor([0.979, 0.02, 0.001])  # Model predicted distribution

# Compute the contributions of each class to Cross Entropy
cross_entropy_contributions = -P * torch.log(Q)
total_cross_entropy = torch.sum(cross_entropy_contributions)
cross_entropy_ratios = cross_entropy_contributions / total_cross_entropy

# Compute the contributions of each class to KL Divergence
kl_divergence_contributions = P * (torch.log(P) - torch.log(Q))
total_kl_divergence = torch.sum(kl_divergence_contributions)

# Calculate the absolute proportion of each class's contribution to KL Divergence
kl_divergence_absolute_ratios = torch.abs(kl_divergence_contributions) / torch.sum(torch.abs(kl_divergence_contributions))

# Print the results for Cross Entropy contributions
print("Cross Entropy Contributions:")
for i, contrib in enumerate(cross_entropy_contributions):
    print(f"Class {i}: {contrib.item()} (Proportion: {cross_entropy_ratios[i].item():.2%})")

# Print the results for KL Divergence contributions
print("\nKL Divergence Contributions:")
for i, contrib in enumerate(kl_divergence_contributions):
    print(f"Class {i}: {contrib.item()} (Absolute Proportion: {kl_divergence_absolute_ratios[i].item():.2%})")


Output:

Cross Entropy Contributions:
Class 0: 0.020799191668629646 (Proportion: 16.12%)
Class 1: 0.039120230823755264 (Proportion: 30.33%)
Class 2: 0.06907755136489868 (Proportion: 53.55%)

KL Divergence Contributions:
Class 0: 0.0010005577933043242 (Absolute Proportion: 3.23%)
Class 1: -0.006931471638381481 (Absolute Proportion: 22.39%)
Class 2: 0.02302584983408451 (Absolute Proportion: 74.38%)

我們可以看到,在 Cross-Entropy 中最關心的還是大機率的類別,也就是 class 0,在我們的預測分佈跟真實分佈只差 0.001 時,Cross-Entropy 的 loss 就貢獻了 16.12%,相比之下 KL Divergence 只貢獻了 3.23%;不過反來說,在 class 2 的情況我們的預測分佈與真實分佈差了一個數量級,但是 在 KL Divergence 貢獻的 loss 就硬是比 Cross Entropy 多了 20%。

這確實說明了 KL Divergence 其實看重整體分佈的形狀(每一個類別),而 Cross Entropy 則更關注真實分佈中容易出現的類別。


延伸討論

不過在我們拿 Cross Entropy 和 KL Divergence 來當作損失函數(loss function)而非拿來衡量資訊系統的差異時,可以視為等價的

這是因為 KL Divergence 和 Cross Entropy 只差一個目標分佈的熵 $H(p)$,並且在目標分佈為 ground truth 的情況下就是一個妥妥的常數項。

D_{KL}(p\parallel q)=H(p,q)-H(p)

H(x)=p(x)log(p(x))

所以在微分後,使用 KL Divergence 和 Cross Entropy 所得到的梯度應該是相同的。

相關的討論也可以參考 https://ai.stackexchange.com/questions/3065/why-has-the-cross-entropy-become-the-classification-standard-loss-function-and-n 這個討論。


總結

現在的話,我似乎能理解我正在實作、實驗的那篇論文為什麼是採用 Cross Entropy 而非 KL Divergence 了。我很直覺地希望 draft model 和 target model 之間的輸出機率分佈應該相似且一致,但是搞不好原作者團隊其實更看重高機率的類別(因為解碼時需要讓 draft model 產生出來的 token 被 target model 所接受)。

我決定再多做一點實驗來確認,因為在我的初步的實驗結果中,採用 KL Divergence 的接受率比 Cross-Entropy 高兩個數量級 —— 當然,也有一種可能是因為我採樣的溫度太高了,導致 KL Divergence 效果就是好。


References


Read More

Leave a Reply