Last Updated on 2024-12-03 by Clay
Introduction
Recently, while implementing the paper Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting, I encountered a question about its use of Cross Entropy Loss to align the probability distributions of the draft model and the target model. Why not use KL Divergence instead?
To put it simply, both Cross Entropy and KL Divergence measure the relationship between two probability distributions: Cross Entropy assesses how similar they are, while KL Divergence quantifies the distance (more precisely, the divergence) between them.
Here, we define p as the true distribution and q as the predicted distribution. From a deep learning perspective, we aim for the predicted distribution q to closely resemble the true distribution p.
Mathematical Definitions
Cross-Entropy
H(p, q) = -\sum_{x} p(x) \log q(x)
KL Divergence
D_{KL}(p \| q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} = \sum_{x} p(x) \left( \log p(x) - \log q(x) \right)
In essence, Cross Entropy can be seen as KL Divergence plus the entropy of the true distribution:
H(p, q) = D_{KL}(p \| q) + H(p)
For classification with hard (one-hot) labels, the true distribution p assigns probability 1 to the correct class and 0 to every other class.
Why are they equivalent in this case? This is because the additional term H(p) = -\sum_{x} p(x) \log p(x) is 0 for a one-hot distribution: the correct class contributes 1 \cdot \log 1 = 0, and the zero-probability classes contribute nothing (by the convention 0 \log 0 = 0).
Therefore:
H(p, q) = D_{KL}(p \| q)
In this situation, KL Divergence and Cross Entropy are not merely similar; they are identical!
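As a quick numerical sanity check, here is a minimal PyTorch sketch (the tensors and helper functions are my own illustration, not code from the Kangaroo paper) that verifies H(p, q) = D_KL(p || q) + H(p) for a soft target, and that the two losses coincide when the target is one-hot:
import torch

def cross_entropy(p, q):
    # H(p, q) = -sum_x p(x) * log q(x)
    return -(p * q.log()).sum()

def kl_divergence(p, q):
    # D_KL(p || q) = sum_x p(x) * (log p(x) - log q(x)), skipping p(x) = 0 terms
    mask = p > 0
    return (p[mask] * (p[mask].log() - q[mask].log())).sum()

def entropy(p):
    # H(p) = -sum_x p(x) * log p(x), skipping p(x) = 0 terms
    mask = p > 0
    return -(p[mask] * p[mask].log()).sum()

q = torch.tensor([0.6, 0.3, 0.1])  # model prediction

# Soft target: Cross Entropy = KL Divergence + entropy of the target
p_soft = torch.tensor([0.7, 0.2, 0.1])
print(cross_entropy(p_soft, q))                    # these two lines
print(kl_divergence(p_soft, q) + entropy(p_soft))  # print the same value

# Hard (one-hot) target: H(p) = 0, so the two losses are identical
p_hard = torch.tensor([1.0, 0.0, 0.0])
print(cross_entropy(p_hard, q))   # -log(0.6) ≈ 0.5108
print(kl_divergence(p_hard, q))   # -log(0.6) ≈ 0.5108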
Differences Between Cross Entropy and KL Divergence
Now we arrive at the main topic: What are the differences between Cross Entropy and KL Divergence?
As mentioned earlier, they are equivalent for hard labels. However, for soft labels, the differences become apparent.
In Cross Entropy, the loss contribution of a specific category x is calculated as
-p(x) \log q(x)
KL Divergence, on the other hand, computes the contribution of category x as
p(x) \left( \log p(x) - \log q(x) \right)
Here’s an example:
import torch
# Define the distributions P and Q
P = torch.tensor([0.98, 0.01, 0.01]) # Target distribution
Q = torch.tensor([0.979, 0.02, 0.001]) # Model predicted distribution
# Compute the contributions of each class to Cross Entropy
cross_entropy_contributions = -P * torch.log(Q)
total_cross_entropy = torch.sum(cross_entropy_contributions)
cross_entropy_ratios = cross_entropy_contributions / total_cross_entropy
# Compute the contributions of each class to KL Divergence
kl_divergence_contributions = P * (torch.log(P) - torch.log(Q))
total_kl_divergence = torch.sum(kl_divergence_contributions)
# Calculate the absolute proportion of each class's contribution to KL Divergence
kl_divergence_absolute_ratios = torch.abs(kl_divergence_contributions) / torch.sum(torch.abs(kl_divergence_contributions))
# Print the results for Cross Entropy contributions
print("Cross Entropy Contributions:")
for i, contrib in enumerate(cross_entropy_contributions):
    print(f"Class {i}: {contrib.item()} (Proportion: {cross_entropy_ratios[i].item():.2%})")
# Print the results for KL Divergence contributions
print("\nKL Divergence Contributions:")
for i, contrib in enumerate(kl_divergence_contributions):
    print(f"Class {i}: {contrib.item()} (Absolute Proportion: {kl_divergence_absolute_ratios[i].item():.2%})")
Output:
Cross Entropy Contributions:
Class 0: 0.020799191668629646 (Proportion: 16.12%)
Class 1: 0.039120230823755264 (Proportion: 30.33%)
Class 2: 0.06907755136489868 (Proportion: 53.55%)

KL Divergence Contributions:
Class 0: 0.0010005577933043242 (Absolute Proportion: 3.23%)
Class 1: -0.006931471638381481 (Absolute Proportion: 22.39%)
Class 2: 0.02302584983408451 (Absolute Proportion: 74.38%)
From the results, we can observe that Cross Entropy places more weight on high-probability categories such as class 0: even though the predicted and true probabilities differ by only 0.001 (0.979 vs. 0.98), class 0 accounts for 16.12% of the Cross Entropy loss but only 3.23% of the KL Divergence. Conversely, for class 2, where the prediction is off by a factor of ten (0.001 vs. 0.01), KL Divergence assigns a much larger share of the loss than Cross Entropy does: 74.38% versus 53.55%, a gap of roughly 20 percentage points.
This demonstrates that KL Divergence emphasizes the overall distribution shape, while Cross Entropy prioritizes categories with higher true probabilities.
Extended Discussion
When we use Cross Entropy and KL Divergence as loss functions for training (rather than as information-theoretic measures of the difference between two distributions), they can be considered equivalent.
This is because KL Divergence and Cross Entropy differ only by the entropy H(p) of the target distribution. When the target distribution is the fixed ground truth (it does not depend on the model's parameters), this term is a constant.
The relationship can be expressed as:
H(p, q) = H(p) + D_{KL}(p \| q)
Where entropy is defined as:
H(p) = -\sum_{x} p(x) \log p(x)
Thus, when differentiating with respect to the model's prediction q (or the parameters that produce it), the gradients obtained from KL Divergence and Cross Entropy are the same.
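To verify this, here is a minimal sketch (my own example, with arbitrary logits) that compares the gradients of the two losses with respect to the same logits for a fixed soft target. Note that torch.nn.functional.kl_div expects log-probabilities as its input and probabilities as its target:
import torch
import torch.nn.functional as F

# Fixed soft target distribution (does not depend on the logits)
target = torch.tensor([0.98, 0.01, 0.01])

# Two independent copies of the same logits so the gradients can be compared
logits_ce = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)
logits_kl = logits_ce.detach().clone().requires_grad_(True)

# Cross Entropy: H(p, q) = -sum_x p(x) * log q(x)
ce_loss = -(target * F.log_softmax(logits_ce, dim=-1)).sum()
ce_loss.backward()

# KL Divergence: D_KL(p || q); kl_div takes log q as input and p as target
kl_loss = F.kl_div(F.log_softmax(logits_kl, dim=-1), target, reduction="sum")
kl_loss.backward()

# The loss values differ by the constant H(p), but the gradients are identical
print(ce_loss.item(), kl_loss.item())
print(torch.allclose(logits_ce.grad, logits_kl.grad))  # True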
For further discussion, you can refer to this thread on AI Stack Exchange:
Why has the cross-entropy become the classification standard loss function and not KL divergence?
Conclusion
Now, I think I understand why the paper I'm implementing and experimenting with uses Cross Entropy instead of KL Divergence. While my intuition favors aligning the full probability distributions of the draft and target models, the authors may have wanted the loss to prioritize the higher-probability categories (so that the draft model generates tokens the target model is likely to accept during decoding).
I plan to conduct further experiments to verify this. In my preliminary results, KL Divergence outperformed Cross Entropy by two orders of magnitude in acceptance rates, although this might be because I used a high sampling temperature, which favors KL Divergence.
References
- Wikipedia - Kullback–Leibler divergence
- KL Divergence vs Cross Entropy: Exploring the Differences ...