Skip to content

[PyTorch] Image: 利用生成對抗網路 DCGAN 生成圖片

今天我將紀錄如何使用 DCGAN 來實做簡單的『生成圖片』模型。本來我想要用美味的點心圖片來示範(我確實地下載了五十萬張點心圖片),但奈何效果不怎麼好,最後還是用回了官方示範的 CelebA。

以下的程式碼,已經與官方的有些差別了,但基本上最初還是參考官方的模型設計來寫。如果想直接看官方的教學,也許你可以參考: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html


CelebA 介紹

Large-scale CelebFaces Attributes (CelebA),為著名的名人臉部圖片資料集,並且有用 Bounding Box 來標注臉部,是由香港大學的 Multimedia Lab 建立。

這個 datasets 一共有 10177 個人物、202599 張臉部圖片、 每張圖片皆為 178 x 218 解析度。

20 萬張圖片也算是相當多了。這是我們生成對抗網路的 Training data。


DCGAN 簡介

DCGAN 現在普遍被認為是 GAN (生成對抗網路) 的擴展,全名為 Deep Convolutional Generative Adversarial Network。

顧名思義,便是將 CNN 的概念加入了生成對抗網路。這種模型訓練的架構基本上分為 Generator 以及 Discriminator 兩種模型。

Generator 負責從真實的照片中產生圖片,這些圖片全部都會被標上『fake』的標籤(通常為 0),而另一方面, Discriminator 則一個二元分類器,負責判斷真正的照片以及虛假的照片。

常見的作法為先訓練其中一個模型、然後在模型的 Loss 較低時換成另外一個模型,然後另一個模型訓練至 Loss 較低時再切換回來原本的模型繼續往下訓練 ….. 就這樣,兩個模型彼此競爭,漸漸地, Generator 產生出來的圖片就會開始讓 Discriminator 難辨真假,我們就達成了我們的目的 —— 產生真正有用的生成圖片模型。

不過,以上都是理想狀況。

真實世界裡,GAN (生成對抗網路)是個不好控制的模型,有可能會變成 Discriminator 變得分辨率太過強大,漸漸導致 Generator 的權重不管怎麼針對 Loss 進行 backward propagation 都沒有用,於是變成 Training 越來越失去效果 —— 反正 Discriminator 都會分辨為假照片,Generator 的權重當然怎麼 backward propagation 都沒差!

所以整體而言,GAN 是個很花費精神的模型,最好每個階段都儲存個模型下來,並時刻關切 Discriminator 以及 Generator 的 Loss 有沒有呈現改進,也許輔以 TensorBoard 是個不錯的想法。


資料準備

資料需要準備的格式為:

celeba
|__ img_align_celeba

img_align_celeba 裡面才是我們所有的 jpg 檔,img_align_celeba 外面那層的資料夾才是我們要賦予給程式碼 DataLoader 的路徑。要小心不要搞錯了,我剛開始花了一些時間才釐清我弄錯的路徑。

下載的資料可以參考:這裡


Discriminator.py

# -*- coding: utf-8 -*-
import torch.nn as nn


# Discriminator
class Discriminator(nn.Module):
    def __init__(self, inputSize, hiddenSize):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(inputSize, hiddenSize, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hiddenSize, hiddenSize*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hiddenSize*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hiddenSize*2, hiddenSize*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hiddenSize*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hiddenSize*4, hiddenSize*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hiddenSize*8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(hiddenSize*8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, input):
        return self.main(input)


這算也算是很經典的 CNN 模型,Convolution 層加上正規化的 BatchNorm,再使用激活函數 LeakyReLU 輸出。

印出模型後會得到以下的文字:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Generator.py

# -*- coding: utf-8 -*-
import torch.nn as nn


# Generator
class Generator(nn.Module):
    def __init__(self, inputSize, hiddenSize, outputSize):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(inputSize, hiddenSize*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(hiddenSize*8),
            nn.ReLU(True),

            nn.ConvTranspose2d(hiddenSize*8, hiddenSize*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hiddenSize*4),
            nn.ReLU(True),

            nn.ConvTranspose2d(hiddenSize*4, hiddenSize*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hiddenSize*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(hiddenSize*2, hiddenSize, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hiddenSize),
            nn.ReLU(True),

            nn.ConvTranspose2d(hiddenSize, outputSize, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, input):
        return self.main(input)



Generator 這邊有個比較少見到的模型層: ConvTranspose2d()。基本上常被譯作『轉置卷積』以及『反卷積』。由於 Generator 是接受一些隨機取樣的 Noise 來當作輸入並希望能產生出一張圖片,故需要用這種模型層加上 backward propagation 調整權重來讓那些 Noise 能真正形成一張常見的照片。


Train.py

以上兩種模型都定義好了之後,我們該來進行訓練的部份了。

# -*- coding: utf-8 -*-
import random
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from DCGAN.generator import Generator
from DCGAN.discriminator import Discriminator


匯入我們需要的 Package。 DCGAN.generator 以及 DCGAN.discriminator 是我兩個模型撰寫的位置,只要你能夠正常地匯入,隨便怎麼設定都可以。

# CUDA
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)


# Random seed
manualSeed = 7777
print('Random Seed:', manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)



確認 GPU 是否可用 (這次不用 GPU 是真的有點慢。)

另外,固定 Seed,不管是 Numpy 或是 torch 本身的 Seed。

# Attributes
dataroot = 'celeba'

batch_size = 1024
image_size = 64
G_out_D_in = 3
G_in = 100
G_hidden = 64
D_hidden = 64

epochs = 5
lr = 0.001
beta1 = 0.5



參數設定,這個大家可以任意調整。

# Data
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                           ]))

# Create the dataLoader
dataLoader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)



正如同上面所說,要注意 root 的路徑為我們圖片的資料夾的上一層資料夾。

# Weights
def weights_init(m):
    classname = m.__class__.__name__
    print('classname:', classname)

    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)



初始化權重!這裡其實是按照最初提出 DCGAN 論文的要求來處理的。你也可以拿掉,試試看效果。

# Train
def train():
    # Create the generator
    netG = Generator(G_in, G_hidden, G_out_D_in).to(device)
    netG.apply(weights_init)
    print(netG)

    # Create the discriminator
    netD = Discriminator(G_out_D_in, D_hidden).to(device)
    netD.apply(weights_init)
    print(netD)

    # Loss fuG_out_D_intion
    criterion = nn.BCELoss()
    fixed_noise = torch.randn(64, G_in, 1, 1, device=device)

    real_label = 1
    fake_label = 0
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    img_list = []
    G_losses = []
    D_losses = []
    iters = 0
    print('Start!')

    for epoch in range(epochs):
        for i, data in enumerate(dataLoader, 0):
            # Update D network
            netD.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, device=device)
            output = netD(real_cpu).view(-1)

            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, G_in, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach()).view(-1)

            errD_fake = criterion(output, label)
            errD_fake.backward()

            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            # Update G network
            netG.zero_grad()
            label.fill_(real_label)
            output = netD(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, epochs, i, len(dataLoader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 500 == 0) or ((epoch == epochs - 1) and (i == len(dataLoader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()

                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1

    torch.save(netD, 'netD.pkl')
    torch.save(netG, 'netG.pkl')

    return G_losses, D_losses



真正的 Training,可以看到我們為了節省麻煩,將 Discriminator 和 Generator 一起訓練,這就是 PyTorch Tutorial 裡面的作法。

最後,我們來畫出我們兩種模型的 Loss,藉此觀察狀況,並且也印出真假圖片的比較:

# Plot
def plotImage(G_losses, D_losses):
    print('Start to plot!!')
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataLoader))

    # Plot the real images
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1, 2, 0)))

    # Plot the fake images from the last epoch
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.show()



Output:

也許從小圖看來 Fake 圖片還挺像人臉的 —— 但若放大來看會覺得這完成度還遠遠不足。

希望日後還有機會能夠紮實地研究一下這個模型該怎麼改進,不管怎麼說,圖像有關的東西都相當挺有趣的!


Read More

Leave a Reply