A Generative Adversarial Network (GAN) is a well-known neural network model made of two parts. The Generator takes a batch of random noise as input and turns it into fake images, and the Discriminator tries to tell whether each image it sees is real or fake.
Training alternates between the two. The Discriminator is trained as a simple real-vs-fake classifier, then the Generator is trained to fool it, and the two keep competing round after round. When training is done, we keep the Generator, and from then on we can feed it random noise to generate as many new images as we like!
That is really all there is to the basic idea of a GAN.
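For readers who like to see it written down: the original GAN paper (Goodfellow et al., 2014) describes this competition as a minimax game over a value function V(D, G), where D(x) is the Discriminator's probability that x is real and G(z) is the image the Generator produces from noise z. The Discriminator tries to maximize V and the Generator tries to minimize it:

```latex
\min_G \max_D V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice, including in the `g_loss_function` used in this post, the Generator usually minimizes -log D(G(z)) instead, which gives it stronger gradients early in training, when the Discriminator rejects the fakes easily.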
MNIST
MNIST is a very well-known dataset of handwritten digits. If you already know it, feel free to skip this section.
It could be called the "Hello World" of machine learning.
The MNIST dataset contains 60,000 training images and 10,000 test images. The 70,000 digits were written by high school students and US Census Bureau employees. Each image is 28x28 pixels, and each pixel is an 8-bit grayscale value (0 to 255).
What makes this dataset so valuable is that every image already comes with a label: the digit it shows, from 0 to 9. The labels are stored as integers and are often converted to one-hot vectors when training a classifier.
To learn more about the MNIST dataset, visit: http://yann.lecun.com/exdb/mnist/
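As a small illustration of the data format (pure Python, no PyTorch needed), the sketch below converts an integer label into a one-hot vector and flattens a 28x28 image into the 784 values that the Discriminator in this post takes as input. `one_hot` and `flatten` are hypothetical helpers for illustration; in PyTorch, `torch.nn.functional.one_hot` does the label conversion and `tensor.view(-1, 784)` does the flattening. Note that the GAN in this post never uses the labels; it only looks at the images.

```python
def one_hot(label: int, num_classes: int = 10) -> list[int]:
    """Hypothetical helper: turn an integer label 0-9 into a one-hot list."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be in [0, {num_classes}), got {label}")
    vector = [0] * num_classes
    vector[label] = 1
    return vector


def flatten(image: list[list[int]]) -> list[int]:
    """Hypothetical helper: read a 28x28 image row by row into 784 values."""
    return [pixel for row in image for pixel in row]


blank_image = [[0] * 28 for _ in range(28)]  # an all-background 28x28 image
print(one_hot(3))                 # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
print(len(flatten(blank_image)))  # 784
```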
Model Definition
The following programs are written in PyTorch. If you are new to PyTorch, you can refer to: [PyTorch] Tutorial(1) What is Tensor?
If you are interested in GANs, you can also refer to: [PyTorch] Tutorial(7) Use Deep Generative Adversarial Network (DCGAN) to generate pictures
Below, I explain the important parts of the code step by step; the complete code is at the end.
First, I defined my model architecture in the model.py file.
```python
# -*- coding: utf-8 -*-
import torch.nn as nn


# Input: a flattened 28x28 image (784 values). Output: probability that it is real.
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


# Input: a 128-dim noise vector. Output: 784 values in [-1, 1] (a flattened 28x28 image).
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(128, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)
```
The Discriminator is a classifier for MNIST images, so its input size is 784 (a 28x28 MNIST image flattened into a vector). The fully-connected layers then shrink from 784 inputs down to 256 neurons and finally to a single output neuron, whose sigmoid activation gives the probability that the input is a real image. (For more on the sigmoid function, see: [PyTorch] Set the threshold of Sigmoid output and convert it to binary value)
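To make the last step concrete, here is a minimal pure-Python sketch of what the final sigmoid does: it squashes the output neuron's raw value (a logit) into a probability between 0 and 1. The `to_binary` helper and its 0.5 threshold are illustrative assumptions; the training code never thresholds, because `BCELoss` works on the probability directly.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + e^-x), written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0 here, so this cannot overflow
    return z / (1.0 + z)


def to_binary(probability: float, threshold: float = 0.5) -> int:
    """Hypothetical helper: 1 = judged real, 0 = judged fake."""
    return 1 if probability >= threshold else 0


for logit in (-3.0, 0.0, 2.0):
    p = sigmoid(logit)
    print(f"logit={logit:+.1f}  D(x)={p:.3f}  label={to_binary(p)}")
```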
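The Generator works in the opposite direction: it takes a 128-dimensional noise vector and outputs 784 values, which form a 28x28 image. Like the Discriminator, it uses only fully-connected layers, and its final Tanh activation keeps every pixel value in [-1, 1].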
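One quick way to compare the two networks is to count their learnable parameters. Each `nn.Linear(n_in, n_out)` layer holds an n_in x n_out weight matrix plus n_out biases, and the activation layers have no parameters. The sketch below applies that arithmetic to the layer widths from model.py; the helper name `linear_param_count` is hypothetical. With the real models, `sum(p.numel() for p in model.parameters())` should give the same numbers.

```python
def linear_param_count(sizes: list[int]) -> int:
    """Weights plus biases of a chain of nn.Linear layers with the given widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


# Layer widths taken from model.py; ReLU, LeakyReLU, Tanh and Sigmoid add no parameters.
discriminator_sizes = [784, 256, 256, 1]
generator_sizes = [128, 1024, 1024, 784]

print("Discriminator:", linear_param_count(discriminator_sizes))  # Discriminator: 267009
print("Generator:", linear_param_count(generator_sizes))          # Generator: 1985296
```

By this count, the Generator is roughly seven times larger than the Discriminator.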
Note that this is not necessarily the best model configuration. Feel free to use it as a starting point and try different configurations; it is quite interesting.
Training
Import all the packages we will use.
```python
# -*- coding: utf-8 -*-
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

from model import discriminator, generator

import numpy as np
import matplotlib.pyplot as plt
```
Here I record the start time and define a plotting helper that arranges images in a square grid. If you don't need them, you can delete these lines.
```python
start_time = time.time()
plt.rcParams['image.cmap'] = 'gray'


def show_images(images):
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))

    for index, image in enumerate(images):
        plt.subplot(sqrtn, sqrtn, index+1)
        plt.imshow(image.reshape(28, 28))
```
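`show_images` places N images on the smallest square grid that holds them, using ceil(sqrt(N)) as the side length. Matplotlib's `plt.subplot(rows, cols, index)` numbers positions from 1, left to right and then top to bottom, which is why the code passes `index+1`. The pure-Python sketch below reproduces that arithmetic; `grid_side` and `subplot_position` are hypothetical helpers, not part of the original code.

```python
import math


def grid_side(num_images: int) -> int:
    """Same formula as show_images: side of the smallest square grid holding every image."""
    return int(math.ceil(math.sqrt(num_images)))


def subplot_position(index: int, side: int) -> tuple[int, int]:
    """0-based (row, column) of plt.subplot(side, side, index + 1); positions run row by row."""
    return divmod(index, side)


side = grid_side(16)
print(side)                       # 4
print(subplot_position(5, side))  # (1, 1)
```

For the 16 samples shown during training this gives a 4x4 grid; 10 images would also use a 4x4 grid, leaving the last six cells empty.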
Here I define the loss functions of the Discriminator and the Generator; both use binary cross-entropy (BCELoss). The Discriminator's loss measures the distance between its predictions and the actual answers (1 for real images, 0 for generated ones). The Generator's loss compares the Discriminator's predictions on generated images with a target of 1 ("real"): the more often the Generator fools the Discriminator into calling its images real, the lower its loss.
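```python
# Discriminator Loss => BCELoss
def d_loss_function(inputs, targets):
    return nn.BCELoss()(inputs, targets)


# Generator Loss => BCELoss against all-ones ("real") targets.
# Uses the global `device`, which is defined in the next step.
def g_loss_function(inputs):
    targets = torch.ones([inputs.shape[0], 1])
    targets = targets.to(device)
    return nn.BCELoss()(inputs, targets)
```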
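To see what these two losses reward, here is a small pure-Python sketch of mean binary cross-entropy, the quantity `nn.BCELoss` computes with its default `reduction='mean'`, applied to made-up Discriminator outputs. The scores are hypothetical; only the labelling scheme (real = 1 and fake = 0 for the Discriminator, fake = 1 for the Generator) follows the code above. Unlike `nn.BCELoss`, which clamps its log terms at -100, this sketch does not guard against probabilities of exactly 0 or 1.

```python
import math


def bce(predictions: list[float], targets: list[float]) -> float:
    """Mean binary cross-entropy; predictions must lie strictly inside (0, 1)."""
    total = 0.0
    for p, y in zip(predictions, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(predictions)


# Hypothetical Discriminator outputs for 2 real images and 2 generated ones.
real_scores = [0.9, 0.8]
fake_scores = [0.2, 0.1]

# Discriminator: real -> 1, fake -> 0 (same concatenation as the training loop).
d_loss = bce(real_scores + fake_scores, [1.0, 1.0, 0.0, 0.0])

# Generator: wants its fakes labelled 1, like g_loss_function.
g_loss = bce(fake_scores, [1.0, 1.0])

print(f"d_loss = {d_loss:.4f}, g_loss = {g_loss:.4f}")
```

Here the Discriminator separates real from fake well, so its loss is low (about 0.164), while the Generator's loss is high (about 1.956) because its fakes score only 0.2 and 0.1. As D(G(z)) rises toward 1, the Generator's loss falls toward 0.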
Next, set the training parameters and load the training data. The real images are only needed to train the Discriminator; training the Generator only requires randomly generated noise.
```python
# GPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

# Model
G = generator().to(device)
D = discriminator().to(device)
print(G)
print(D)

# Settings
epochs = 200
lr = 0.0002
batch_size = 64
g_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Transform: ToTensor() gives values in [0, 1]; Normalize maps them to [-1, 1] to match Tanh
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load data
train_set = datasets.MNIST('mnist/', train=True, download=True, transform=transform)
test_set = datasets.MNIST('mnist/', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
```
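The Normalize step and the noise formula are both simple arithmetic. `transforms.ToTensor()` produces pixel values in [0, 1], and `Normalize((0.5,), (0.5,))` computes (x - 0.5) / 0.5, which maps them to [-1, 1], the same range as the Generator's Tanh output, so real and generated images live on the same scale. The training loop applies the inverse, (x + 1) / 2, before plotting, and builds its noise as (torch.rand(...) - 0.5) / 0.5, which turns uniform samples from [0, 1) into uniform samples from [-1, 1). The pure-Python sketch below checks these mappings; the function names are hypothetical, and Python's `random.random()` stands in for `torch.rand`.

```python
import random


def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """What Normalize((0.5,), (0.5,)) does to one ToTensor() value in [0, 1]."""
    return (pixel - mean) / std


def denormalize(value: float) -> float:
    """Inverse used before plotting: (x + 1) / 2 maps [-1, 1] back to [0, 1]."""
    return (value + 1.0) / 2.0


def sample_noise(dim: int = 128) -> list[float]:
    """Analogue of (torch.rand(n, 128) - 0.5) / 0.5 for one vector: uniform on [-1, 1)."""
    return [(random.random() - 0.5) / 0.5 for _ in range(dim)]


print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
print(denormalize(-1.0), denormalize(1.0))              # 0.0 1.0

noise = sample_noise()
print(len(noise), all(-1.0 <= v < 1.0 for v in noise))  # 128 True
```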
Training code:
```python
# Train
for epoch in range(1, epochs + 1):
    for times, data in enumerate(train_loader, start=1):
        # ---- Discriminator ----
        # Real images flattened to 784 values (the labels in data[1] are not used)
        real_inputs = data[0].to(device)
        real_inputs = real_inputs.view(-1, 784)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        # Uniform noise in [-1, 1)
        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        # detach(): the Discriminator update should not backpropagate into G
        fake_outputs = D(fake_inputs.detach())
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        # Zero the parameter gradients
        d_optimizer.zero_grad()

        # Backward propagation
        d_loss = d_loss_function(outputs, targets)
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator ----
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        g_loss = g_loss_function(fake_outputs)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if times % 100 == 0 or times == len(train_loader):
            print('[{}/{}, {}/{}] D_loss: {:.3f} G_loss: {:.3f}'.format(epoch, epochs, times, len(train_loader), d_loss.item(), g_loss.item()))

    # Show 16 generated samples at the end of each epoch (Tanh output mapped back to [0, 1])
    imgs_numpy = (fake_inputs.data.cpu().numpy()+1.0)/2.0
    show_images(imgs_numpy[:16])
    plt.show()

    if epoch % 50 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

print('Training Finished.')
print('Cost Time: {}s'.format(time.time()-start_time))
```
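A quick back-of-the-envelope check of the loop above: with 60,000 training images and `batch_size = 64`, `len(train_loader)` is ceil(60000 / 64) = 938, because `DataLoader` keeps the final partial batch by default (`drop_last=False`). That last batch holds only 32 images, which is one reason the code sizes its labels and noise with `real_inputs.shape[0]` rather than `batch_size`. The pure-Python sketch below reproduces this arithmetic and lists the steps at which the progress line is printed; the constant names are hypothetical.

```python
import math

TRAIN_IMAGES = 60_000  # size of the MNIST training split
BATCH_SIZE = 64        # batch_size in the settings above
EPOCHS = 200           # epochs in the settings above

# DataLoader keeps the last, smaller batch by default (drop_last=False), hence ceil.
batches_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
last_batch_size = TRAIN_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE

# Same condition as the training loop: every 100th batch, plus the final one.
print_steps = [t for t in range(1, batches_per_epoch + 1)
               if t % 100 == 0 or t == batches_per_epoch]

print(batches_per_epoch)           # 938
print(last_batch_size)             # 32
print(print_steps)                 # [100, 200, 300, 400, 500, 600, 700, 800, 900, 938]
print(batches_per_epoch * EPOCHS)  # 187600
```

Since the loop takes one Discriminator step and one Generator step per batch, 200 epochs amount to 187,600 updates of each network.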
Output:
The results are not perfect yet, so feel free to try different configurations.
Complete Code
model.py
```python
# -*- coding: utf-8 -*-
import torch.nn as nn


# Input: a flattened 28x28 image (784 values). Output: probability that it is real.
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


# Input: a 128-dim noise vector. Output: 784 values in [-1, 1] (a flattened 28x28 image).
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(128, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)
```
mnist_train.py
```python
# -*- coding: utf-8 -*-
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

from model import discriminator, generator

import numpy as np
import matplotlib.pyplot as plt


start_time = time.time()
plt.rcParams['image.cmap'] = 'gray'


def show_images(images):
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))

    for index, image in enumerate(images):
        plt.subplot(sqrtn, sqrtn, index+1)
        plt.imshow(image.reshape(28, 28))


# Discriminator Loss => BCELoss
def d_loss_function(inputs, targets):
    return nn.BCELoss()(inputs, targets)


# Generator Loss => BCELoss against all-ones ("real") targets (uses the global `device`)
def g_loss_function(inputs):
    targets = torch.ones([inputs.shape[0], 1])
    targets = targets.to(device)
    return nn.BCELoss()(inputs, targets)


# GPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

# Model
G = generator().to(device)
D = discriminator().to(device)
print(G)
print(D)

# Settings
epochs = 200
lr = 0.0002
batch_size = 64
g_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Transform: ToTensor() gives values in [0, 1]; Normalize maps them to [-1, 1] to match Tanh
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load data
train_set = datasets.MNIST('mnist/', train=True, download=True, transform=transform)
test_set = datasets.MNIST('mnist/', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Train
for epoch in range(1, epochs + 1):
    for times, data in enumerate(train_loader, start=1):
        # ---- Discriminator ----
        real_inputs = data[0].to(device)
        real_inputs = real_inputs.view(-1, 784)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        # detach(): the Discriminator update should not backpropagate into G
        fake_outputs = D(fake_inputs.detach())
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        # Zero the parameter gradients
        d_optimizer.zero_grad()

        # Backward propagation
        d_loss = d_loss_function(outputs, targets)
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator ----
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        g_loss = g_loss_function(fake_outputs)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if times % 100 == 0 or times == len(train_loader):
            print('[{}/{}, {}/{}] D_loss: {:.3f} G_loss: {:.3f}'.format(epoch, epochs, times, len(train_loader), d_loss.item(), g_loss.item()))

    # Show 16 generated samples at the end of each epoch
    imgs_numpy = (fake_inputs.data.cpu().numpy()+1.0)/2.0
    show_images(imgs_numpy[:16])
    plt.show()

    if epoch % 50 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

print('Training Finished.')
print('Cost Time: {}s'.format(time.time()-start_time))
```
test.py
```python
import torch
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'


def show_images(images):
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))

    for index, image in enumerate(images):
        plt.subplot(sqrtn, sqrtn, index+1)
        plt.imshow(image.reshape(28, 28))


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

# Model
# torch.save(G, ...) pickled the whole model, so model.py must be importable here.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# weights_only=False is required for whole-model pickles on PyTorch 2.6+
# (the argument exists since PyTorch 1.13).
G = torch.load('Generator_epoch_200.pth', map_location=device, weights_only=False)
G.eval()

# Generator
noise = (torch.rand(16, 128)-0.5) / 0.5
noise = noise.to(device)

with torch.no_grad():
    fake_image = G(noise)

imgs_numpy = (fake_image.cpu().numpy()+1.0)/2.0
show_images(imgs_numpy)
plt.show()
```
If the code is hard to read on this page, you can also find it on my GitHub: https://github.com/ccs96307/PyTorch-Mnist-GAN