Skip to content

[Machine Learning] Variational AutoEncoder (VAE) 筆記

介紹

變分自動編碼器(Variational AutoEncoder, VAE) 是自動編碼器(AutoEncoder, AE)的進階變體,架構與原本的自動編碼器相似,同樣都是由編碼器(Encoder)和解碼器(Decoder)所組成。

但是 VAE 有幾個與本來的 AutoEncoder 不同的特點:

  1. 使用潛在空間(latent space)向量替換編碼(encode)
    原本的 AutoEncoder 擁有一個確定的編碼過程,將輸入的向量通過編碼器的神經網路轉換成一組潛在特徵向量,而 VAE 則是透過生成『平均值』和『變異數』兩組向量,並讓向量中的元素兩兩配對生成一組常態分佈的向量。

請參考 VAE 的架構圖與下方公式。

其中,⊙ 為元素對元素的乘法。並且 ϵ 是從標準常態分佈 N(0, 1) 抽樣來的。並且在文獻中,通常會看到用平方根的標準差 σ 表示,這也是為了讓保持為正數。

  1. 重參數化(reparameterization trick)
    如果我們直接讓 Encoder 的任務變成不接受任何輸入(或僅有noise為輸入)直接輸出常態分佈(帶有均值和標準差),我們將面臨一個問題 —— 這些隨機產生的常態分佈畢竟是隨機的,我們將難以進行 backward propagation 來更新模型權重,梯度會卡在 decoder 那一層。

而重參數化換了個角度去生成常態分佈採樣:我們透過 encoder 產生的,是均值與方差,以此來生成固定的一組組的常態分佈採樣。之所以是一組組不是一組,是因為我們可能產生的均值有 mean1, mean2, …, meanN,並且方差有 var1, var2, …, varN,所以我們總共有 N 組常態分佈採樣。

但若是讓模型直接透過神經網路生成這些組成常態分佈採樣的參數,我們基本上做的就是普通的神經網路輸出,組成常態分佈採樣的隨機化部分壓根沒經過神經網路,自然就可以做梯度下降來更新模型權重了。

另外,由於我們最後要讓 VAE 生成假資料,是透過常態分佈的輸入給 decoder 並讓 decoder 解碼出假資料。若是我們在訓練時非得讓 encoder 生成常態分佈,即意味著我們需要額外給定一個輸入(可能是噪音或圖片)給 encoder —— 當然在某些情況下我們可能會需要這樣綁定生成特徵的能力。

  1. 引入 KL 散度當作 loss function 的一部分
    這是為了讓我們重參數化的輸出更貼其標準常態分佈。之所以選擇 KL 散度有以下原因:
  • 數學性質:因為 KL 散度與最大化似然律有關,所以在機率建模中是一個合理的評估指標
  • 非對稱性:P 和 Q 之間的 KL 散度和 Q 和 P 之間的 KL 散度不同
  • 高效計算:對標準常態分佈來說,KL 散度可以直接計算

而 loss function 的另外一個部分,就是跟原本的 AutoEncoder 一樣的重建損失(reconstruction loss),我們可以直接透過 MSE 來計算輸入與最終輸出之間的落差,以此來改進。

所以整個 VAE 模型等於需要同時學習兩件事情:一個就是中間的隨機常態分佈要跟標準常態分佈相像、而通過編碼器、解碼器解碼出來的最終輸出需要與輸入越像越好。


VAE 其他特性

當然,除了以上重點之外,VAE 也還具備以下有趣的特色:

模型架構

VAE 可以使用任意深度學習神經網路層來實現,比如全連接層、CNN 和 RNN。實現的自由度非常高。但值得注意的是,在寫 loss function 時,有個乘以 KL 散度的 Beta 係數,該 Beta 係數的值越大,意味著模型越是關注中間 latent space 的常態分佈採樣需要越接近標準常態分佈越好,而 Beta 值越小,則是重建原始圖片的重要性越高。

這個 Beta 值並不好設,我反覆調整了幾次,這才讓模型在兩邊的 loss 達到平衡 —— 在測試的過程中,更是找到了 Beta-VAE 的架構,可以把 Beta 參數也變成模型學習的一部分,自動化地去設定。


生成能力

訓練結束後,VAE 的 decoder 部份可以被單獨提取之來,並使用我們從標準常態分佈中的採樣資料作為輸入,解碼出新的、從未見過的資料。


潛在空間(latent space)性質

VAE 的潛在空間有著良好的表現能力,比方說相似的資料會在潛在空間中叢聚在一起。


變體與發展

VAE 還有著許多的變體,比如 conditional VAE、Beta-VAE、disentangled VAE… 等等。這些都是為了解決特定的問題而被研究出的變體。

當然,VAE 也有其限制(尤其是最基礎的版本),比方說難以訓練或是生成的圖像較為模糊等等。


程式

以下來介紹我實作的程式部分。

首先是模型的架構。

from typing importTupleimport torch


class VAE(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        code_dim = 2# Encoder
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(784, 128),
            torch.nn.ReLU(),
        )

        # Mean & Variance
        self.mean_layer = torch.nn.Linear(128, code_dim)
        self.log_variance_layer = torch.nn.Linear(128, code_dim)

        # Decoder
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(code_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 784),
            torch.nn.Sigmoid(),
        )

    def reparameterization(self, means: torch.Tensor, log_variances: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * log_variances)
        eps = torch.randn_like(std)
        return means + eps * std

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encoder
        hiddens = self.encoder(inputs)

        # Mean & Variance
        means = self.mean_layer(hiddens)
        log_variances = self.log_variance_layer(hiddens)

        # Get gaussion distribution
        gaussian_distribution_codes = self.reparameterization(means=means, log_variances=log_variances)

        # Decoder
        decoded = self.decoder(gaussian_distribution_codes)

        return means, log_variances, gaussian_distribution_codes, decoded


訓練的程式,需要把剛才的 models.py 中定義好的模型架構 import 進來才行。

# coding: utf-8import torch
import torchvision
from models import VAE


def main() -> None:
    # Settings
    device = torch.device("cuda:0"if torch.cuda.is_available() else"cpu")
    epochs = 100
    batch_size = 128
    lr = 2e-3# DataLoader
    train_dataset = torchvision.datasets.MNIST(
        root="../../data/MNIST/",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Model
    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_function = torch.nn.MSELoss()

    # Trainfor epoch inrange(1, epochs+1):
        for inputs, labels in train_dataloader:
            inputs = inputs.view(-1, 784).to(device)

            # Forward
            means, log_vars, codes, decoded = model(inputs)

            # Backward
            optimizer.zero_grad()

            # Reconstruction loss
            reconstruction_loss = loss_function(decoded, inputs)

            # KL divergence loss
            KL_divergence = torch.mean(-0.5 * torch.sum(1 + log_vars - means ** 2 - log_vars.exp(), dim=1), dim=0)

            # Total loss
            beta = 0.001
            loss = reconstruction_loss + beta * KL_divergence
            loss.backward()

            optimizer.step()

        # Show progressprint(f"[{epoch}/{epochs}] Loss: {loss.item()}")

    # Save
    torch.save(model, "VAE.pth")


if __name__ == "__main__":
    main()


訓練結束後,我們也可以可以看到視覺化的結果。

# coding: utf-8import torch
import torchvision
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from models import VAE


def main() -> None:
    # Load model
    model = torch.load("VAE.pth")
    model.to("cuda:0").eval()
    print(model)

    # DataLoader
    test_dataset = torchvision.datasets.MNIST(
        root="../../data/MNIST/",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

    # Plot
    axis_x = []
    axis_y = []
    answers = []
    with torch.no_grad():
        for data in test_dataloader:
            inputs = data[0].view(-1, 28*28).to("cuda:0")
            answers += data[1].tolist()

            means, log_vars, code, outputs = model(inputs)
            axis_x += code[:, 0].tolist()
            axis_y += code[:, 1].tolist()


    # Use a colormap
    colors = [cm.rainbow(i/9) for i in answers]
    scatter = plt.scatter(axis_x, axis_y, c=colors)
    cbar = plt.colorbar(scatter)
    cbar.set_ticks([i/9for i inrange(10)])
    cbar.set_ticklabels([str(i) for i inrange(10)])
    plt.show()
    plt.savefig("VAE.png")


if __name__ == "__main__":
    main()

Output:


比起原本的 AutoEncoder,相同標籤的資料顯然更能表示群聚在一起的潛在空間特性了。

[Machine Learning] AutoEncoder 基本介紹 (附 PyTorch 程式碼)

我們也能透過 VAE,來畫出最具代表性的由潛在空間生成的圖像。因為這能呈現出潛在空間的連續性,相似的點會對應到相似的圖像。

這次的 VAE 實作算是一個很好玩的體驗,我以為我對 PyTorch 以及深度學習的概念已經足夠熟悉了,尤其是在今年做完 [論文閱讀] The Forward-Forward Algorithm: Some Preliminary Investigation 的實現後 —— 但實際上要不參考他人的實作,自己從零開始刻起還是跌跌撞撞地遇到許多問題,包括 loss function 的公式部分一開始我寫錯了好幾次。

研究所時有個同學似乎能花一天看論文,然後花了一天就能完全復現 VAE 的結果,現在想想真的是相當厲害,真的是該跟人家多學習學習才是。


References


Read More

Leave a Reply