**MNIST image generation with a GAN: an implementation in Keras**
Generative adversarial networks, or GANs, are said to be hugely popular, but actually implementing one yourself is quite a hurdle.
The technology looked important to me, yet I had only ever admired it from the outside and never touched it. I suspect quite a few people are in the same position.
**This time, I will introduce an example of implementing a GAN on the MNIST data.** The data and code are borrowed from ["Unsupervised learning with python"](https://www.amazon.co.jp/Python%E3%81%A7%E3%81%AF%E3%81%98%E3%82%81%E3%82%8B%E6%95%99%E5%B8%AB%E3%81%AA%E3%81%97%E5%AD%A6%E7%BF%92-%E2%80%95%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%81%AE%E5%8F%AF%E8%83%BD%E6%80%A7%E3%82%92%E5%BA%83%E3%81%92%E3%82%8B%E3%83%A9%E3%83%99%E3%83%AB%E3%81%AA%E3%81%97%E3%83%87%E3%83%BC%E3%82%BF%E3%81%AE%E5%88%A9%E7%94%A8-Ankur-Patel/dp/4873119103) by Ankur Patel.
The book's code is written in an object-oriented style, so it was a little advanced for me, but it was a great learning experience.
I hope this article is similarly helpful for beginners.
First, here are the results, since they make the strongest impression.
**Real images**
**Generated images**
**The generated images are eerily similar to the real ones.** Training for longer might produce even better results.
Here is a brief overview of how a GAN works. For details, please refer to this article: [GAN: What is a generative adversarial network? Image generation by "unsupervised learning"](https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%A8%E3%81%AF%E4%BD%95%E3%81%8B%E3%80%80%EF%BD%9E%E3%80%8C%E6%95%99%E5%B8%AB/)
**A GAN learns a dataset and creates new data that looks as if it came from that same dataset.** In the referenced article's example, a GAN generates photos of bedrooms that do not actually exist. They are hard to tell apart from real photos; machine learning is scary.
Since this article uses MNIST, we will generate handwritten digits. So how are these digits generated?
**A GAN has two models: one that generates data and one that discriminates (identifies) data. The generative model creates data that looks like handwritten digits. The discriminative model then judges whether that data is fake or genuine. Based on the result, the generative model is trained further so that it creates images closer to the real thing.**
Put simply, that is all there is to the model. The remaining question is how the two models are trained.
In this model, training is performed as follows.
- Generate an image (1×28×28) from noise (100×1×1) with the generative model.
- Train the discriminative model on "real images" and "images created by the generative model".
- Create new images with the generative model, and train the generative and discriminative models together so that the discriminative model classifies the generated images as "real images".
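To make the two training objectives above concrete, here is a minimal NumPy sketch of the binary cross-entropy losses involved (the Keras models below are compiled with `loss='binary_crossentropy'`). The discriminator is pushed to output 1 for real images and 0 for generated ones, while the generator is trained against the label 1 for its own images. The helper names and the probabilities are hypothetical, chosen only for illustration.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))))

def discriminator_loss(d_real, d_fake):
    """Discriminator step: real images are labelled 1, generated images 0."""
    y_pred = np.concatenate([d_real, d_fake])
    y_true = np.concatenate([np.ones_like(d_real), np.zeros_like(d_fake)])
    return binary_crossentropy(y_true, y_pred)

def generator_loss(d_fake):
    """Adversarial step: generated images are labelled 1 ("real")."""
    return binary_crossentropy(np.ones_like(d_fake), d_fake)

# Hypothetical discriminator outputs (probability of "real") for a tiny batch
d_real = np.array([0.9, 0.8])
d_fake = np.array([0.2, 0.1])
print(discriminator_loss(d_real, d_fake))  # small: the discriminator separates real from fake
print(generator_loss(d_fake))              # large: the generator is not fooling it yet
```

The two losses pull in opposite directions on the same discriminator outputs, which is what makes the training "adversarial".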
Now let's actually implement this model.
The code is essentially the reference book's, slightly modified so that it runs on Google Colab.
```python
# Main
import numpy as np
import pandas as pd
import os, time, re
import pickle, gzip, datetime
# Data Viz
import matplotlib.pyplot as plt
import seaborn as sns
color = sns.color_palette()
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import Grid
%matplotlib inline
# Data Prep and Model Evaluation
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import log_loss, accuracy_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import roc_curve, auc, roc_auc_score, mean_squared_error
from keras.utils import to_categorical
# Algos
import lightgbm as lgb
# TensorFlow and Keras
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Activation, Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.layers import LeakyReLU, Reshape, UpSampling2D, Conv2DTranspose
from keras.layers import BatchNormalization, Input, Lambda
from keras.layers import Embedding, Flatten, dot
from keras import regularizers
from keras.losses import mse, binary_crossentropy
from IPython.display import SVG
# keras.utils.vis_utils is gone in recent Keras; model_to_dot lives in keras.utils
from keras.utils import model_to_dot
from keras.optimizers import Adam, RMSprop
from keras.datasets import mnist
sns.set("talk")
```
Next, load the data. We use MNIST, and the code assumes Google Colaboratory. Since only `x_train` is used, only `x_train` is reshaped and normalized to values between 0 and 1.
```python
# Split the data into training data and test data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Add a channel axis: (60000, 28, 28) -> (60000, 28, 28, 1)
x_train = x_train.reshape((60000, 28, 28, 1))
# Normalize pixel values to the range 0-1
x_train = x_train.astype('float32') / 255.0
```
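The reshape and normalization above can also be wrapped in a small helper. This is only a sketch, not part of the book's code: `prepare_images` is a hypothetical name, and it uses `reshape((-1, 28, 28, 1))` so it does not depend on the sample count being exactly 60000. A random `uint8` array stands in for `mnist.load_data()` so the sketch runs without downloading anything.

```python
import numpy as np

def prepare_images(x, rows=28, cols=28):
    """Add a channel axis and scale uint8 pixels (0-255) to float32 in [0, 1]."""
    x = x.reshape((-1, rows, cols, 1))
    return x.astype("float32") / 255.0

# Stand-in for the uint8 image array returned by mnist.load_data()
fake_mnist = np.random.randint(0, 256, size=(10, 28, 28), dtype=np.uint8)
x = prepare_images(fake_mnist)
print(x.shape, x.dtype)  # (10, 28, 28, 1) float32
```

Scaling to [0, 1] matches the generator's final sigmoid, so real and generated images share the same value range.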
Next comes the most important part: the DCGAN code. It is a class that bundles the generative model and the discriminative model. Briefly, each method does the following:
- `generator`: a neural network that converts a 100×1×1 noise vector into a 28×28×1 image. Through training, it learns to generate plausible-looking images.
- `discriminator`: a neural network that judges whether a 28×28×1 image is genuine or fake.
- `discriminator_model`: wraps the discriminator network into a model and compiles it.
- `adversarial_model`: a model that connects `generator` and `discriminator`; the generative network is trained through this model.
```python
# DCGAN class
class DCGAN(object):
    # Initialization
    def __init__(self, img_rows=28, img_cols=28, channel=1):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model

    # Generative network
    # Converts a 100*1*1 noise vector into a 28*28*1 image, the same shape as the dataset images
    def generator(self, depth=256, dim=7, dropout=0.3, momentum=0.8,
                  window=5, input_dim=100, output_depth=1):
        if self.G is not None:
            return self.G
        self.G = Sequential()
        # 100*1*1 → 7*7*256
        self.G.add(Dense(dim*dim*depth, input_dim=input_dim))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth)))
        self.G.add(Dropout(dropout))
        # 7*7*256 → 14*14*128
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(depth // 2, window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))
        # 14*14*128 → 28*28*64
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(depth // 4, window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))
        # 28*28*64 → 28*28*32
        self.G.add(Conv2DTranspose(depth // 8, window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))
        # 28*28*32 → 28*28*1
        self.G.add(Conv2DTranspose(output_depth, window, padding='same'))
        # Squash each pixel to a value between 0 and 1
        self.G.add(Activation('sigmoid'))
        self.G.summary()
        return self.G

    # Discriminative network
    # Judges whether a 28*28*1 image is genuine
    def discriminator(self, depth=64, dropout=0.3, alpha=0.3):
        if self.D is not None:
            return self.D
        self.D = Sequential()
        input_shape = (self.img_rows, self.img_cols, self.channel)
        # 28*28*1 → 14*14*64
        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))
        # 14*14*64 → 7*7*128
        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))
        # 7*7*128 → 4*4*256 ('same' padding with stride 2 rounds 7/2 up to 4)
        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))
        # 4*4*256 → 4*4*512 (stride 1 keeps the spatial size)
        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))
        # Flatten and classify with a sigmoid
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))
        self.D.summary()
        return self.D

    # Discriminative model
    def discriminator_model(self):
        if self.DM is not None:
            return self.DM
        # The book also passed decay=6e-8; newer Keras optimizers no longer support
        # `decay`, so it is omitted (its effect over a few thousand steps is negligible)
        optimizer = RMSprop(learning_rate=0.0002)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        self.DM.compile(loss='binary_crossentropy',
                        optimizer=optimizer, metrics=['accuracy'])
        return self.DM

    # Adversarial model: generator followed by discriminator
    def adversarial_model(self):
        if self.AM is not None:
            return self.AM
        optimizer = RMSprop(learning_rate=0.0001)  # book: lr=0.0001, decay=3e-8
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(self.discriminator())
        # The discriminator is not frozen here, so training this model
        # updates both the generator's and the discriminator's weights
        self.AM.compile(loss='binary_crossentropy',
                        optimizer=optimizer, metrics=['accuracy'])
        return self.AM
```
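The shape comments in the class can be checked by hand. With `padding='same'`, a convolution with stride s turns a side of length n into ceil(n / s), and `UpSampling2D()` doubles it. The sketch below is plain Python with hypothetical helper names; it just traces (side length, channels) through both networks using the default arguments above.

```python
import math

def same_conv(size, stride):
    """Output side length of a 'same'-padded convolution: ceil(size / stride)."""
    return math.ceil(size / stride)

def discriminator_shapes(side=28, depth=64):
    """(side, channels) after each Conv2D layer of the discriminator."""
    shapes = []
    for channels, stride in [(depth, 2), (depth * 2, 2), (depth * 4, 2), (depth * 8, 1)]:
        side = same_conv(side, stride)
        shapes.append((side, channels))
    return shapes

def generator_shapes(dim=7, depth=256):
    """(side, channels) after Dense+Reshape and after each Conv2DTranspose of the generator."""
    shapes = [(dim, depth)]            # Dense -> Reshape((7, 7, 256))
    side = dim * 2                     # UpSampling2D
    shapes.append((side, depth // 2))  # Conv2DTranspose(128), stride 1 keeps the size
    side *= 2                          # UpSampling2D
    shapes.append((side, depth // 4))  # Conv2DTranspose(64)
    shapes.append((side, depth // 8))  # Conv2DTranspose(32)
    shapes.append((side, 1))           # Conv2DTranspose(1) + sigmoid
    return shapes

print(discriminator_shapes())  # [(14, 64), (7, 128), (4, 256), (4, 512)]
print(generator_shapes())      # [(7, 256), (14, 128), (28, 64), (28, 32), (28, 1)]
```

So the last discriminator convolution maps 4×4×256 to 4×4×512, and `Flatten` feeds 4·4·512 = 8192 values into the final `Dense(1)`.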
Next, we use these classes to actually train on the MNIST data and generate images. The `train` function trains the models, and `plot_images` displays or saves the images.
The `train` function runs in the following order:
- Generate fake training data from noise.
- Train the discriminative model on the generated data together with real images. How well it could tell them apart (loss and accuracy) is stored in `d_loss`.
- Train `adversarial_model` so that the generated data is judged to be real. How well the discriminator was fooled (loss and accuracy) is stored in `a_loss`.
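The data handling in these steps can be tried without Keras. In this sketch, `fake_generator` is a hypothetical stand-in for `self.generator.predict` (it just returns random images of the right shape), `make_batches` is an illustrative name, and NumPy's `default_rng` Generator replaces `np.random.*` for reproducibility. The sampling, concatenation, and labels follow the `train` function.

```python
import numpy as np

def fake_generator(noise):
    """Hypothetical stand-in for generator.predict: (n, 100) noise -> (n, 28, 28, 1) images."""
    rng = np.random.default_rng(0)
    return rng.uniform(0.0, 1.0, size=(noise.shape[0], 28, 28, 1))

def make_batches(x_train, batch_size, rng):
    """Build the discriminator batch (real + fake) and the adversarial batch."""
    # Randomly pick batch_size real images
    images_train = x_train[rng.integers(0, x_train.shape[0], size=batch_size)]
    # Turn uniform noise in [-1, 1) into fake images
    noise = rng.uniform(-1.0, 1.0, size=(batch_size, 100))
    images_fake = fake_generator(noise)
    x = np.concatenate((images_train, images_fake))
    # Real images are labelled 1, generated images 0
    y = np.ones((2 * batch_size, 1))
    y[batch_size:, :] = 0
    # Adversarial step: fresh noise, and every generated image is labelled 1
    noise_adv = rng.uniform(-1.0, 1.0, size=(batch_size, 100))
    y_adv = np.ones((batch_size, 1))
    return x, y, noise_adv, y_adv

rng = np.random.default_rng(42)
x_train = rng.uniform(0.0, 1.0, size=(100, 28, 28, 1))  # stand-in for the MNIST images
x, y, noise_adv, y_adv = make_batches(x_train, batch_size=8, rng=rng)
print(x.shape, y.ravel())  # (16, 28, 28, 1), then eight 1.0 labels and eight 0.0 labels
```

In the real `train` function, `x, y` go to `self.discriminator.train_on_batch` and `noise_adv, y_adv` go to `self.adversarial.train_on_batch`.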
```python
# A class that applies DCGAN to the MNIST data
class MNIST_DCGAN(object):
    # Initialization
    def __init__(self, x_train):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1
        self.x_train = x_train
        # Define the DCGAN discriminative and adversarial models
        self.DCGAN = DCGAN()
        self.discriminator = self.DCGAN.discriminator_model()
        self.adversarial = self.DCGAN.adversarial_model()
        self.generator = self.DCGAN.generator()

    # Training function
    # train_on_batch trains on a single batch and returns [loss, acc]
    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None
        if save_interval > 0:
            # Fixed noise, so the saved snapshots are comparable across steps
            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])
        for i in range(train_steps):
            # Randomly take batch_size images from the training data
            images_train = self.x_train[np.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]
            # Generate batch_size noise vectors of size 100 and turn them into fake images
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.generator.predict(noise, verbose=0)
            x = np.concatenate((images_train, images_fake))
            # Label the training data 1 and the generated data 0
            y = np.ones([2*batch_size, 1])
            y[batch_size:, :] = 0
            # Train the discriminative model
            d_loss = self.discriminator.train_on_batch(x, y)
            # Train through the adversarial model with the generated images labelled 1
            # The generative model is trained only here
            y = np.ones([batch_size, 1])
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            a_loss = self.adversarial.train_on_batch(noise, y)
            # D loss: loss and acc on the generated and real images
            # A loss: loss and acc when the images generated by adversarial are labelled 1
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)
            # Save sample images every save_interval steps
            if save_interval > 0 and (i + 1) % save_interval == 0:
                self.plot_images(save2file=True,
                                 samples=noise_input.shape[0],
                                 noise=noise_input, step=(i + 1))

    # Plot (or save) real or generated images in a 4*4 grid
    def plot_images(self, save2file=False, fake=True, samples=16,
                    noise=None, step=0):
        # Images are saved under ./data/images/chapter12/synthetic_mnist/
        save_dir = os.path.join(os.getcwd(), 'data', 'images', 'chapter12', 'synthetic_mnist')
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise, verbose=0)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]
        plt.figure(figsize=(10, 10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i + 1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            # Create the output directory if it does not exist yet
            os.makedirs(save_dir, exist_ok=True)
            plt.savefig(os.path.join(save_dir, filename))
            plt.close('all')
        else:
            plt.show()
```
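`plot_images` draws the 16 samples with one `plt.subplot` per image. An alternative, sketched here with a hypothetical `make_grid` helper that is not from the book, is to tile the batch into a single 2-D array with NumPy and display it with one `plt.imshow(grid, cmap='gray')` call. Only the tiling is shown, so the sketch runs without a display.

```python
import numpy as np

def make_grid(images, rows=4, cols=4):
    """Tile (rows*cols, h, w, 1) images into one (rows*h, cols*w) array, row by row."""
    n, h, w, _ = images.shape
    if n != rows * cols:
        raise ValueError("expected %d images, got %d" % (rows * cols, n))
    grid = images[..., 0].reshape(rows, cols, h, w)    # (rows, cols, h, w)
    return grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

# 16 tiny 2x2 "images"; image k is filled with the value k
images = np.arange(16, dtype=float).repeat(4).reshape(16, 2, 2, 1)
grid = make_grid(images)
print(grid.shape)  # (8, 8)
```

The transpose interleaves the image rows with the grid rows, so image k lands in grid cell (k // cols, k % cols), the same order `plt.subplot(4, 4, k + 1)` uses.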
GANs are amazing. It is almost eerie to watch something that looks like real handwriting being generated.
They also seem usable in practice, for example for anomaly detection.
However, the summary of the reference book warns: **"Please be prepared for a great deal of effort when using GANs."** It gave no detailed reason for this...
Just how hard are you, GAN?
Thank you for reading to the end.