
Differences and Comparison Between KL Divergence and Cross Entropy

Last Updated on 2024-12-03 by Clay

Introduction

Recently, while implementing the paper Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting, I encountered a question about its use of Cross Entropy Loss to align the probability distributions of the draft model and the target model. Why not use KL Divergence instead?

To put it simply, both Cross Entropy and KL Divergence measure how far apart two distributions are: Cross Entropy measures the total cost of describing the true distribution with the predicted one, while KL Divergence measures only the extra cost, which is why it is often interpreted as a "distance".

Here, we define p as the true distribution and q as the predicted distribution. From a deep learning perspective, we aim for the predicted distribution q to closely resemble the true distribution p.


Mathematical Definitions

Cross-Entropy

H(p,q)=-\sum_{x}p(x)log(q(x))

KL Divergence

H(p)=-\sum_{x}p(x)log(p(x)),

D_{KL}(p\parallel q)=\sum_{x}p(x)log(\frac{p(x)}{q(x)})=H(p,q)-H(p)
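
As a quick numerical check, here is a minimal sketch (the distributions below are made-up values for illustration) that computes all three quantities in PyTorch and confirms the identity D_{KL}(p\parallel q)=H(p,q)-H(p):

import torch

# Made-up example distributions, purely for illustration
p = torch.tensor([0.7, 0.2, 0.1])  # "true" distribution
q = torch.tensor([0.6, 0.3, 0.1])  # predicted distribution

cross_entropy = -(p * q.log()).sum()             # H(p, q)
entropy = -(p * p.log()).sum()                   # H(p)
kl_divergence = (p * (p.log() - q.log())).sum()  # D_KL(p || q)

print(f"H(p, q)        = {cross_entropy.item():.6f}")
print(f"H(p)           = {entropy.item():.6f}")
print(f"H(p, q) - H(p) = {(cross_entropy - entropy).item():.6f}")
print(f"D_KL(p || q)   = {kl_divergence.item():.6f}")  # matches H(p, q) - H(p)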

In essence, Cross Entropy can be seen as KL Divergence plus the entropy of the true distribution H(p). However, when the true distribution p is one-hot encoded (where one category has a probability of 1 and all others are 0), Cross Entropy and KL Divergence are equivalent.

Why are they equivalent in this case? This is because the additional term H(p), representing the entropy of the target distribution, simplifies as:

H(p)=-\sum_{x}p(x)log(p(x)). Since only one category has p(x)=1 and all others are 0, the term for that category is log(p(x))=log(1)=0, while the remaining terms vanish because 0\cdot log(0) is taken as 0. Thus, the entropy H(p)=0.

Therefore:

D_{KL}(p\parallel q)=\sum_{x}p(x)log(\frac{p(x)}{q(x)})=H(p,q)-H(p)=H(p,q)-0=H(p,q)

In this situation, KL Divergence and Cross Entropy are not merely similar—they are identical!
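
A tiny sanity check of this equivalence (the one-hot target and the predicted distribution below are made-up values):

import torch

p = torch.tensor([0.0, 1.0, 0.0])  # one-hot "true" distribution
q = torch.tensor([0.1, 0.7, 0.2])  # predicted distribution

# By convention 0 * log(0) = 0, so only keep the classes with p(x) > 0
mask = p > 0
cross_entropy = -(p[mask] * q[mask].log()).sum()
kl_divergence = (p[mask] * (p[mask].log() - q[mask].log())).sum()

print(cross_entropy.item())  # -log(0.7) ≈ 0.3567
print(kl_divergence.item())  # identical: the log(p(x)) = log(1) = 0 term adds nothing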


Differences Between Cross Entropy and KL Divergence

Now we arrive at the main topic: What are the differences between Cross Entropy and KL Divergence?

As mentioned earlier, they are equivalent for hard labels. However, for soft labels, the differences become apparent.

In Cross Entropy, the loss for a specific category x is calculated as -p(x)log(q(x)). This means categories with extremely small p(x) contribute almost nothing to the loss, so how well the prediction matches those low-probability categories is effectively ignored.

KL Divergence, on the other hand, computes the loss as p(x)log(\frac{p(x)}{q(x)}), focusing on minimizing the relative difference log(\frac{p(x)}{q(x)}) (i.e., making p(x) and q(x) closer). Thus, it emphasizes the overall shape of the distribution.

Here’s an example:

import torch

# Define the distributions P and Q
P = torch.tensor([0.98, 0.01, 0.01])  # Target distribution
Q = torch.tensor([0.979, 0.02, 0.001])  # Model predicted distribution

# Compute the contributions of each class to Cross Entropy
cross_entropy_contributions = -P * torch.log(Q)
total_cross_entropy = torch.sum(cross_entropy_contributions)
cross_entropy_ratios = cross_entropy_contributions / total_cross_entropy

# Compute the contributions of each class to KL Divergence
kl_divergence_contributions = P * (torch.log(P) - torch.log(Q))
total_kl_divergence = torch.sum(kl_divergence_contributions)

# Calculate the absolute proportion of each class's contribution to KL Divergence
kl_divergence_absolute_ratios = torch.abs(kl_divergence_contributions) / torch.sum(torch.abs(kl_divergence_contributions))

# Print the results for Cross Entropy contributions
print("Cross Entropy Contributions:")
for i, contrib in enumerate(cross_entropy_contributions):
    print(f"Class {i}: {contrib.item()} (Proportion: {cross_entropy_ratios[i].item():.2%})")

# Print the results for KL Divergence contributions
print("\nKL Divergence Contributions:")
for i, contrib in enumerate(kl_divergence_contributions):
    print(f"Class {i}: {contrib.item()} (Absolute Proportion: {kl_divergence_absolute_ratios[i].item():.2%})")


Output:

Cross Entropy Contributions:
Class 0: 0.020799191668629646 (Proportion: 16.12%)
Class 1: 0.039120230823755264 (Proportion: 30.33%)
Class 2: 0.06907755136489868 (Proportion: 53.55%)

KL Divergence Contributions:
Class 0: 0.0010005577933043242 (Absolute Proportion: 3.23%)
Class 1: -0.006931471638381481 (Absolute Proportion: 22.39%)
Class 2: 0.02302584983408451 (Absolute Proportion: 74.38%)

From the results, we can observe that Cross Entropy focuses more on high-probability categories like class 0: even though the predicted and true probabilities there differ by only 0.001, class 0 still accounts for 16.12% of the Cross Entropy loss, but only 3.23% of the KL Divergence. Conversely, for class 2, where the prediction deviates significantly from the target, KL Divergence assigns it a much larger share of the loss than Cross Entropy does, a gap of about 20 percentage points.

This demonstrates that KL Divergence emphasizes the overall distribution shape, while Cross Entropy prioritizes categories with higher true probabilities.


Extended Discussion

When we use Cross Entropy and KL Divergence as loss functions (rather than as information-theoretic measures of the difference between two distributions), they can be considered equivalent.

This is because KL Divergence and Cross Entropy differ only by the entropy H(p) of the target distribution. When the target distribution is the fixed ground truth, this term is a constant.

The relationship can be expressed as:

D_{KL}(p\parallel q)=H(p,q)-H(p)

Where entropy is defined as:

H(p)=-\sum_{x}p(x)log(p(x))

Thus, when differentiating with respect to the model's prediction q, the gradients obtained from KL Divergence and Cross Entropy are the same.
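
Here is a minimal sketch of this point (the logits and the soft target are made-up values): when q comes from a softmax over the model's logits, both losses yield the same gradient, softmax(logits) - p, because H(p) does not depend on the logits.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, requires_grad=True)  # model outputs (made-up)
p = torch.tensor([0.7, 0.2, 0.1])            # fixed soft target

# Cross Entropy with a soft target: H(p, q) = -sum p(x) * log_softmax(logits)
ce_loss = -(p * F.log_softmax(logits, dim=-1)).sum()
ce_grad = torch.autograd.grad(ce_loss, logits)[0]

# KL Divergence: D_KL(p || q) = sum p(x) * (log p(x) - log_softmax(logits))
kl_loss = (p * (p.log() - F.log_softmax(logits, dim=-1))).sum()
kl_grad = torch.autograd.grad(kl_loss, logits)[0]

print(ce_grad)                           # equals softmax(logits) - p
print(kl_grad)                           # same values
print(torch.allclose(ce_grad, kl_grad))  # True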

For further discussion, you can refer to this thread on AI Stack Exchange:
Why has the cross-entropy become the classification standard loss function and not KL divergence?


Conclusion

Now, I think I understand why the paper I’m implementing and experimenting with uses Cross Entropy instead of KL Divergence. While I intuitively prefer aligning the full probability distributions of the draft and target models, the authors may have prioritized the higher-probability categories (to ensure the draft model generates tokens the target model will accept during decoding).

I plan to conduct further experiments to verify this. In my preliminary results, KL Divergence outperformed Cross Entropy by two orders of magnitude in acceptance rates—although this might be because I used a high sampling temperature, favoring KL Divergence.

