
[Machine Learning] Note Of Kullback-Leibler Divergence

Last Updated on 2024-10-13 by Clay

What is KL Divergence?

In machine learning, we often encounter the term KL Divergence (also known as Kullback-Leibler Divergence). KL Divergence is a measure of how different a probability distribution P is from a second, reference probability distribution Q.

KL Divergence has many different names across various fields, such as relative entropy, information gain, information divergence, etc., but they all essentially refer to the same concept.

For discrete probability distributions, the formula for KL Divergence is as follows:

D_{KL}(P \| Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}

For continuous probability distributions, the formula for KL Divergence is as follows:

D_{KL}(P \| Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} \, dx

Here, P(i) and Q(i) are the probabilities that the two distributions assign to outcome i (p(x) and q(x) are the corresponding probability density functions in the continuous case), and log denotes the natural logarithm.
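
To make the discrete formula concrete, here is a tiny worked example; the two distributions P and Q below are made up purely for illustration:

import numpy as np

# Two made-up discrete distributions over three outcomes (illustration only)
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.4, 0.4, 0.2])

# Apply the discrete formula directly: sum_i P(i) * log(P(i) / Q(i))
kl_pq = np.sum(P * np.log(P / Q))
print(kl_pq)  # a small positive number, since P and Q are fairly similar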

KL Divergence has the following properties:

  • Non-negativity: D(P||Q) ≥ 0, with equality if and only if P and Q are the same distribution.
  • Asymmetry: in general, D(P||Q) ≠ D(Q||P); one direction uses P as the reference distribution, the other uses Q, so KL Divergence is not a true distance metric (see the quick check below).
  • Coding interpretation (information gain): D(P||Q) is the average number of extra bits (or nats, when using the natural logarithm) needed to encode samples drawn from P with a code optimized for Q instead of a code optimized for P.
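
As a quick numeric check of the first two properties, the snippet below reuses the made-up P and Q from above together with scipy.stats.entropy, which returns the relative entropy (i.e., the KL Divergence) when given two distributions:

import numpy as np
from scipy.stats import entropy

# The same made-up example distributions as above
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.4, 0.4, 0.2])

# entropy(P, Q) computes the relative entropy D(P || Q)
d_pq = entropy(P, Q)
d_qp = entropy(Q, P)

print(d_pq, d_qp)    # both values are non-negative...
print(d_pq == d_qp)  # ...but they are not equal: prints False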


After all this explanation, it might still seem a bit abstract. Let me provide a practical example (I wrote about this in an assignment during my master's program... but that might not qualify as a practical application):

Have you heard of PPO? There was a time when many people discussed whether ChatGPT was trained using PPO, but I've forgotten the conclusion.

PPO stands for Proximal Policy Optimization, a policy optimization method in deep reinforcement learning. The main goal of this method is to solve a common problem in policy gradient methods: when updating the policy, the new policy might deviate too much from the original one, leading to unstable learning.

In PPO, this issue can be addressed by adding a penalty term to the loss that measures the KL Divergence between the new policy and the old one (the original PPO paper also proposes a clipped objective as an alternative to the KL penalty). By keeping this KL Divergence small, we ensure that the new policy doesn't deviate too much from the old policy in each update, thereby improving learning stability.

Imagine that the original model A is already proficient at task X, but now we want model A to continue learning task Y. If we train without any additional constraints, the resulting model A' trained on task Y might become quite different from the original model A.

Therefore, when training model A', we continuously calculate the KL Divergence between the outputs of model A' and model A, attempting to minimize the KL Divergence while learning the new task Y. This way, we prevent model A' from straying too far from model A.
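
As a rough sketch of how this can look in code (this is my own simplified illustration, not the actual PPO or ChatGPT training code), assume the frozen reference model A and the trainable model A' both produce logits over the same set of classes or tokens. Using PyTorch, with a hypothetical weight kl_coef, a KL penalty can be added to the task loss roughly like this:

import torch
import torch.nn.functional as F

# Hypothetical logits from the frozen reference model A and the trainable model A'
# (batch_size x num_classes); in real training these come from forward passes
logits_ref = torch.randn(4, 10)
logits_new = torch.randn(4, 10, requires_grad=True)

# Hypothetical task loss for the new task Y (cross-entropy against some labels)
labels = torch.randint(0, 10, (4,))
task_loss = F.cross_entropy(logits_new, labels)

# KL(A' || A), computed directly from the definition: sum p_new * (log p_new - log p_ref)
# (the exact direction and weighting of the KL term vary between implementations)
log_probs_new = F.log_softmax(logits_new, dim=-1)
log_probs_ref = F.log_softmax(logits_ref, dim=-1)
kl_penalty = (log_probs_new.exp() * (log_probs_new - log_probs_ref)).sum(dim=-1).mean()

# Total loss: learn task Y while staying close to the reference model A
kl_coef = 0.1  # hypothetical weight for the KL penalty
loss = task_loss + kl_coef * kl_penalty
loss.backward()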

Does this help you appreciate the brilliance of KL Divergence?


Code Implementation

If you'd like to actually write code to compute KL Divergence, you can refer to the following Python implementation:

import numpy as np
from scipy.stats import norm
from scipy.special import kl_div

# Assume we have two normal distributions: N(0, 1) and N(1, 2)
mu1, sigma1 = 0, 1
mu2, sigma2 = 1, 2

# Generate an evenly spaced grid of values between -10 and 10
x = np.linspace(-10, 10, 1000)
dx = x[1] - x[0]  # spacing between neighboring grid points

# Evaluate the probability densities of the two distributions on the grid
p = norm.pdf(x, mu1, sigma1)
q = norm.pdf(x, mu2, sigma2)

# scipy's kl_div computes the pointwise terms p * log(p / q) - p + q;
# summing them and multiplying by the grid spacing approximates the integral
# form of the KL Divergence (the extra -p + q terms integrate to roughly zero
# for normalized densities)
kl_divergence = np.sum(kl_div(p, q)) * dx

print(kl_divergence)


Output:

0.4431 (approximately)

The Python code above numerically estimates the KL Divergence between the two normal distributions. The result, roughly 0.44, matches the closed-form value for these two Gaussians (see the sanity check below) and indicates a moderate difference between the distributions.

Note that the kl_div function from scipy computes the pointwise terms of the KL Divergence for discrete values. When we apply it to densities sampled on a grid, we still need to sum the terms and multiply by the grid spacing to approximate the integral, so the result is only a numerical approximation of the true KL Divergence, which is defined by integrating over the entire distributions.
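
As that sanity check, the KL Divergence between two univariate Gaussians has a well-known closed-form expression, so we can compare the numerical estimate against it directly:

import numpy as np

def gaussian_kl(mu1, sigma1, mu2, sigma2):
    """Closed-form KL Divergence D( N(mu1, sigma1^2) || N(mu2, sigma2^2) )."""
    return (
        np.log(sigma2 / sigma1)
        + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
        - 0.5
    )

# For N(0, 1) and N(1, 2) this gives about 0.4431, close to the numerical estimate above
print(gaussian_kl(0, 1, 1, 2))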

