
[Machine Learning] Note Of Variational AutoEncoder (VAE)

Last Updated on 2024-08-31 by Clay

Introduction

Variational AutoEncoder (VAE) is an advanced variant of the AutoEncoder (AE). The architecture is similar to the original AutoEncoder, consisting of an encoder and a decoder.

However, VAE has several distinct features compared to the traditional AutoEncoder:

  1. Sampling latent vectors instead of deterministic encoding
    The original AutoEncoder encodes deterministically: the encoder's neural network maps the input vector to a fixed set of latent feature vectors. In contrast, VAE's encoder produces two vectors, a "mean" and a "variance", and pairs up their elements to sample a normally distributed latent vector.

Please refer to the architecture diagram of VAE and the formulas below.

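z = μ + σ ⊙ ϵ
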
Here, ⊙ denotes element-wise multiplication, and ϵ is sampled from a standard normal distribution N(0, 1). In practice, the encoder outputs the log-variance rather than the variance itself, and the standard deviation σ is recovered as exp(0.5 · log σ²), which guarantees that it stays positive.

  2. Reparameterization trick
    If we directly let the encoder output a normal distribution (with mean and variance) and sample from it inside the network, we run into a problem: the sampling operation is random, so backward propagation cannot flow through it. The gradients stop at the sampling step and never reach the encoder, making it impossible to update the model weights.

The reparameterization trick approaches the sampling from a different angle: instead of sampling inside the network, the encoder only outputs means and variances, and the random sample is then assembled from them as mean + std ⊙ ϵ, with ϵ drawn outside the network. The reason there are multiple means and variances rather than a single pair is that each latent dimension has its own parameters: mean1, mean2, ..., meanN and var1, var2, ..., varN together define an N-dimensional normal sample.

Because the network now only produces these parameters (means and variances) as ordinary neural network outputs, while the randomness comes from ϵ and never passes through the network, the path from the encoder's outputs to the decoder remains fully differentiable. As a result, gradient descent can still be used to update the model weights.

Furthermore, the trained VAE ultimately generates fake data by feeding normal distribution samples to the decoder, which decodes them into fake data. If we insist that the encoder produce that normal distribution during training, it implies that we need to provide an additional input (possibly noise or an image) to the encoder. Of course, in some situations we might actually want this binding between an input and the generated features.

  3. Introducing KL divergence as part of the loss function
    This term ensures that the reparameterized latent distribution stays close to the standard normal distribution. There are several reasons for choosing KL divergence:
  • Mathematical properties: KL divergence is related to the maximization of likelihood, making it a reasonable evaluation metric in probabilistic modeling.
  • Asymmetry: The KL divergence between P and Q differs from the KL divergence between Q and P.
  • Efficient computation: for normal distributions, the KL divergence against the standard normal N(0, 1) has a closed form and can be computed directly (see the formula below).
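
Concretely, when the encoder outputs a mean μᵢ and a variance σᵢ² for each latent dimension i, the KL divergence between N(μ, σ²) and the standard normal N(0, 1) has the closed form

KL(N(μ, σ²) ‖ N(0, 1)) = -0.5 × Σᵢ (1 + log σᵢ² - μᵢ² - σᵢ²)

which is exactly the term computed (and averaged over the batch) in the training script below.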

The other part of the loss function is the reconstruction loss, just as in the original AutoEncoder: we can directly use MSE to measure the difference between the input and the reconstructed output and minimize it to train the model.

So the entire VAE model has to learn two things at once: the sampled latent distribution should resemble the standard normal distribution, and the output obtained by passing the input through the encoder and decoder should resemble the input as closely as possible.


VAE Additional Features

Besides the key points mentioned above, VAE also has some other interesting characteristics:

Model Architecture

VAE can be implemented with any type of deep learning layer, such as fully connected layers, CNNs, or RNNs, offering a high degree of flexibility. However, when writing the loss function, the KL divergence term is multiplied by a Beta coefficient. The larger the Beta coefficient, the more the model focuses on making the latent normal distribution samples resemble the standard normal distribution; the smaller the Beta coefficient, the more important the reconstruction of the original image becomes.
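
Putting the two parts together, the training objective takes the form

Loss = ReconstructionLoss + Beta × KL(N(μ, σ²) ‖ N(0, 1))

which is the beta-weighted sum used in the training script below.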

Setting this Beta value is tricky; I had to adjust it several times to balance the two losses. During testing I also came across the Beta-VAE architecture, which makes the Beta coefficient an explicit part of the objective so that the trade-off between the two terms can be controlled more systematically.


Generative Capability

After training, the decoder part of the VAE can be extracted and used alone. We can use samples from the standard normal distribution as input to the decoder, generating new, unseen data.
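
For example, here is a minimal sketch of that kind of generation; it assumes the VAE class from models.py and the VAE.pth checkpoint produced by the training script shown later (the latent size 2 matches code_dim in the model):

# coding: utf-8
import torch
from models import VAE


def main() -> None:
    # Load the trained model (assumed to be the VAE.pth saved by the training script)
    model = torch.load("VAE.pth", map_location="cpu")
    model.eval()

    # Sample latent codes from the standard normal distribution (code_dim = 2)
    z = torch.randn(16, 2)

    # Use only the decoder to turn the codes into new 28x28 images
    with torch.no_grad():
        fake_images = model.decoder(z).view(-1, 28, 28)

    print(fake_images.shape)  # torch.Size([16, 28, 28])


if __name__ == "__main__":
    main()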


Latent Space Properties

The latent space of a VAE has good representation properties; for example, similar data points tend to cluster together in it.


Variants and Developments

There are many variants of VAE, such as Conditional VAE, Beta-VAE, Disentangled VAE, etc. These variants are developed to address specific issues.

Of course, VAE also has its limitations (especially the basic version), such as being difficult to train or generating relatively blurry images.


Code

Below, I walk through the implementation code.

First, the model's architecture:

from typing import Tuple
import torch


class VAE(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        code_dim = 2
        # Encoder
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(784, 128),
            torch.nn.ReLU(),
        )

        # Mean & Variance
        self.mean_layer = torch.nn.Linear(128, code_dim)
        self.log_variance_layer = torch.nn.Linear(128, code_dim)

        # Decoder
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(code_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 784),
            torch.nn.Sigmoid(),
        )

    def reparameterization(self, means: torch.Tensor, log_variances: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * log_variances)
        eps = torch.randn_like(std)
        return means + eps * std

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Encoder
        hiddens = self.encoder(inputs)

        # Mean & Variance
        means = self.mean_layer(hiddens)
        log_variances = self.log_variance_layer(hiddens)

        # Sample latent codes from the Gaussian distribution via the reparameterization trick
        gaussian_distribution_codes = self.reparameterization(means=means, log_variances=log_variances)

        # Decoder
        decoded = self.decoder(gaussian_distribution_codes)

        return means, log_variances, gaussian_distribution_codes, decoded

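As a quick sanity check, the model can be exercised with a random batch to confirm the output shapes (the batch size 4 here is an arbitrary choice):

import torch
from models import VAE

model = VAE()
inputs = torch.randn(4, 784)

means, log_vars, codes, decoded = model(inputs)
print(means.shape)     # torch.Size([4, 2])
print(log_vars.shape)  # torch.Size([4, 2])
print(codes.shape)     # torch.Size([4, 2])
print(decoded.shape)   # torch.Size([4, 784])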

To train the model, you need to import the model architecture defined in models.py.

# coding: utf-8
import torch
import torchvision
from models import VAE


def main() -> None:
    # Settings
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    epochs = 100
    batch_size = 128
    lr = 2e-3

    # DataLoader
    train_dataset = torchvision.datasets.MNIST(
        root="../../data/MNIST/",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Model
    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = torch.nn.MSELoss()

    # Train
    for epoch in range(1, epochs + 1):
        for inputs, labels in train_dataloader:
            inputs = inputs.view(-1, 784).to(device)

            # Forward
            means, log_vars, codes, decoded = model(inputs)

            # Backward
            optimizer.zero_grad()

            # Reconstruction loss
            reconstruction_loss = loss_function(decoded, inputs)

            # KL divergence loss
            KL_divergence = torch.mean(-0.5 * torch.sum(1 + log_vars - means ** 2 - log_vars.exp(), dim=1), dim=0)

            # Total loss
            beta = 0.001
            loss = reconstruction_loss + beta * KL_divergence
            loss.backward()

            optimizer.step()

        # Show progress
        print(f"[{epoch}/{epochs}] Loss: {loss.item()}")

    # Save
    torch.save(model, "VAE.pth")


if __name__ == "__main__":
    main()


After training, we can visualize how the test data is distributed in the 2-dimensional latent space.

# coding: utf-8
import torch
import torchvision
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from models import VAE


def main() -> None:
    # Load model
    model = torch.load("VAE.pth")
    model.to("cuda:0").eval()
    print(model)

    # DataLoader
    test_dataset = torchvision.datasets.MNIST(
        root="../../data/MNIST/",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

    # Plot
    axis_x = []
    axis_y = []
    answers = []
    with torch.no_grad():
        for data in test_dataloader:
            inputs = data[0].view(-1, 28 * 28).to("cuda:0")
            answers += data[1].tolist()

            means, log_vars, code, outputs = model(inputs)
            axis_x += code[:, 0].tolist()
            axis_y += code[:, 1].tolist()


    # Color the points by their digit label using a colormap
    scatter = plt.scatter(axis_x, axis_y, c=answers, cmap=cm.rainbow)
    cbar = plt.colorbar(scatter)
    cbar.set_ticks(list(range(10)))
    cbar.set_ticklabels([str(i) for i in range(10)])
    plt.savefig("VAE.png")
    plt.show()


if __name__ == "__main__":
    main()

Output: a scatter plot of the test set's 2-dimensional latent codes, colored by digit label.


Compared to the original AutoEncoder, it is evident that data with the same labels cluster more closely together in the latent space, demonstrating its clustering properties.

[Machine Learning] AutoEncoder Basic Introduction (with PyTorch Code)

We can also use the VAE to draw the images that different regions of the latent space decode into. This illustrates the continuity of the latent space: nearby points correspond to similar images.
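
For instance, a minimal sketch of such a picture, assuming the VAE.pth checkpoint and the 2-dimensional latent space from the code above (the [-3, 3] grid range and the 15×15 grid size are arbitrary choices):

# coding: utf-8
import torch
import matplotlib.pyplot as plt

from models import VAE


def main() -> None:
    # Load the trained model (assumed to be the VAE.pth saved above)
    model = torch.load("VAE.pth", map_location="cpu")
    model.eval()

    # Decode a grid of latent codes covering [-3, 3] x [-3, 3]
    n = 15
    grid = torch.linspace(-3, 3, n).tolist()
    canvas = torch.zeros(n * 28, n * 28)

    with torch.no_grad():
        for i, y in enumerate(grid):
            for j, x in enumerate(grid):
                z = torch.tensor([[x, y]])
                canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = model.decoder(z).view(28, 28)

    plt.imshow(canvas.numpy(), cmap="gray")
    plt.axis("off")
    plt.savefig("VAE_latent_grid.png")
    plt.show()


if __name__ == "__main__":
    main()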

This VAE implementation was a fun experience. I thought I was already familiar enough with PyTorch and deep learning concepts, especially after completing the implementation of [Paper Reading] The Forward-Forward Algorithm: Some Preliminary Investigation earlier this year. However, it turned out that starting from scratch without referencing others' implementations still led to many stumbling blocks, including making multiple mistakes in the loss function formulas at first.

I recall a classmate from graduate school who could spend a day reading a paper and then another day fully replicating the results of VAE. Looking back, that's quite impressive, and I really should learn more from such people.

