This is a record of my experiment with DCGAN, a variant of GAN, implemented in TensorFlow. I'll explain things roughly without going too deep. I wrote almost the same article the other day, but it got messy, so I've reorganized it a little.
The title says "I want to generate sea slugs!", but at first I was planning to generate Pokemon with DCGAN. So I'll start by briefly describing the Pokemon attempt.
By the way, sea slugs are creatures like this. There are many colorful kinds, and they are beautiful.
First, a brief word about GAN. A GAN trains two networks, a **Generator** that creates fakes and a **Discriminator** that tells fakes from real data, in order to generate data that is as close to the real thing as possible. The Generator creates a new image from random noise so that it resembles the real data. The Discriminator judges whether an image is genuine or was produced by the Generator. The two are good rivals: as this cycle repeats over and over, both get smarter, and as a result the generated images come closer to the real data.
↓ It looks like this
↓ A simplified version
This is the basic mechanism of GAN. DCGAN is a GAN that uses a CNN (Convolutional Neural Network). CNNs are complicated in various ways, but put simply, they build a multi-layer network out of two kinds of layers, **convolution layers** and **pooling layers**, and the convolution layers share the same weights across every position in the image. As a result, DCGAN can learn more efficiently and accurately than a plain GAN.
I will use DCGAN to generate Pokemon and sea slugs. For an explanation of GAN and DCGAN, the articles "GAN you can't ask about anymore (1): Understanding the basic structure" and "GAN you can't ask about anymore (2): Image generation with DCGAN" are easy to understand.
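To make the rivalry between the Generator and the Discriminator a little more concrete, here is a minimal NumPy sketch of the standard GAN losses. It is not the code used in this article: the helper names `bce` and `gan_losses` and the Discriminator outputs are made up for illustration, and a real DCGAN would compute these values from network outputs.

```python
import numpy as np

def bce(probs, target):
    """Binary cross-entropy between predicted probabilities and a constant 0/1 target."""
    eps = 1e-7  # avoid log(0)
    p = np.clip(probs, eps, 1.0 - eps)
    return float(np.mean(-(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))))

def gan_losses(d_real, d_fake):
    """d_real / d_fake: Discriminator outputs (probability of "real") for real and generated images."""
    # The Discriminator wants real -> 1 and fake -> 0.
    d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
    # The Generator wants the Discriminator to say 1 for its fakes.
    g_loss = bce(d_fake, 1.0)
    return d_loss, g_loss

# Hypothetical outputs: a Discriminator that sees through the fakes...
strong_d = gan_losses(np.array([0.95, 0.90]), np.array([0.05, 0.10]))
# ...and one that is being fooled by the Generator.
fooled_d = gan_losses(np.array([0.60, 0.50]), np.array([0.50, 0.60]))

print("strong Discriminator: d_loss=%.3f g_loss=%.3f" % strong_d)
print("fooled Discriminator: d_loss=%.3f g_loss=%.3f" % fooled_d)
```

When the Discriminator wins, its own loss is small and the Generator's loss is large; when it is fooled, the situation is reversed. Training alternates between lowering each of the two losses.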
Pokemon come in many types, and I chose them because I thought a familiar theme would be fun. There really are a lot of Pokemon these days. I downloaded the Pokemon images from [here](https://kamigame.jp/%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3USUM/%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3/%E8%89%B2%E9%81%95%E3%81%84%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3%E4%B8%80%E8%A6%A7.html).
This time I used the Chrome extension **"Image Downloader"** to collect the Pokemon images. I recommend it because it is easy to use and needs no code. The data set seemed too small, so I padded it out with rotated and flipped copies using the code below. The result is saved in .npy format so it is easy to load later.
import glob
import numpy as np
from tqdm import tqdm
from keras.preprocessing.image import load_img, img_to_array
from PIL import Image

# Class names; each one is also the name of the folder holding its images
classes = ["class1", "class2"]
num_classes = len(classes)
img_size = 128
color_mode = "rgb"  # load the images as RGB rather than grayscale

# Load the images
# Images and labels are collected in these two lists
temp_img_array_list = []
temp_index_array_list = []
for index, classlabel in enumerate(classes):
    photos_dir = "./" + classlabel
    # Get the list of images for each class with glob
    img_list = glob.glob(photos_dir + "/*.jpg")
    for img in tqdm(img_list):
        temp_img = load_img(img, color_mode=color_mode, target_size=(img_size, img_size))
        temp_img_array = img_to_array(temp_img)
        temp_img_array_list.append(temp_img_array)
        temp_index_array_list.append(index)
        # Rotation processing (-20 to 15 degrees in steps of 5)
        for angle in range(-20, 20, 5):
            # Rotate
            img_r = temp_img.rotate(angle)
            data = np.asarray(img_r)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)
            # Flip the rotated image left to right
            img_trans = img_r.transpose(Image.FLIP_LEFT_RIGHT)
            data = np.asarray(img_trans)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)

X = np.array(temp_img_array_list)
Y = np.array(temp_index_array_list)
np.save("./img_128RGB.npy", X)
np.save("./index_128RGB.npy", Y)
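As a quick check on how much the loop above inflates the data set, the arithmetic can be worked out directly from the `range(-20, 20, 5)` call. The variable names below are made up for illustration, and the figure of 500 source images is only an example input.

```python
# Angles actually produced by range(-20, 20, 5); the stop value 20 is excluded.
angles = list(range(-20, 20, 5))

# Each source image is stored once as-is, then once per angle (rotated)
# and once more per angle (rotated and flipped left to right).
copies_per_image = 1 + 2 * len(angles)

# Example: how many arrays 500 source images would turn into.
total_for_500 = 500 * copies_per_image

print(angles)
print(copies_per_image, total_for_500)
```

Note that the rotation by 0 degrees is an exact duplicate of the original image, and +20 degrees is never reached. Changing the call to `range(-20, 21, 5)` would include it.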
I wanted to use DCGAN to mix lots of Pokemon together into chimera Pokemon.
↓
But what I actually got was this.
It was clearly overfitting, as you can see from both the generated images and the loss. The Discriminator is far too strong. So next I thought about the causes and tried to address them.
- Pokemon vary so much in color and shape that the output easily turns chaotic, so I wanted a subject whose shape is more uniform. I therefore switched from Pokemon generation to **sea slug generation**.
- That said, sea slugs are not all that uniform in color or shape either, so it is a debatable choice of subject. But working on something you like keeps you motivated.
- I collected a little over 500 sea slug images. Rotation (-20° to 20°) and flipping should multiply this by about 16, so the data grew to roughly **"500 x 16 = 8000"** images.
- The images were collected with **Flickr** and **icrawler**.
- Here is a rough explanation of how to use Flickr. Go to the Flickr API site (https://www.flickr.com/services/api/) and open the **API key** link. Sign in with a Yahoo account and a screen showing your key appears, so get the key from there (mine was blacked out in the screenshot). Use this key to fetch images with the code below.
from flickrapi import FlickrAPI
from urllib.request import urlretrieve
import os, time

# API key information
key = "********"
secret = "********"
wait_time = 1  # seconds to wait between downloads

# Specify the save folder (created if it does not exist yet)
savedir = "./gazou"
os.makedirs(savedir, exist_ok=True)

flickr = FlickrAPI(key, secret, format="parsed-json")
result = flickr.photos.search(
    per_page=100,
    tags="seaslug",
    media="photos",
    sort="relevance",
    safe_search=1,
    extras="url_q,license",
)
photos = result["photos"]

# Loop over the search results and download each photo
for i, photo in enumerate(photos["photo"]):
    url_q = photo["url_q"]
    filepath = savedir + "/" + photo["id"] + ".jpg"
    # Skip files that were already downloaded
    if os.path.exists(filepath):
        continue
    urlretrieve(url_q, filepath)
    time.sleep(wait_time)
This collects a fair amount of data, but I wanted more, so I also collected images with **icrawler**. It is extremely easy to use.
$ pip install icrawler
from icrawler.builtin import GoogleImageCrawler
crawler = GoogleImageCrawler(storage={"root_dir": "gazou"})
crawler.crawl(keyword="Nudibranch", max_num=100)
This alone saves sea slug images to the specified folder. As with the Pokemon images, I padded out the data by rotating and flipping them.
- Briefly, dropout prevents overfitting by randomly ignoring a set fraction of nodes during training.
- For details, this article seems to be good.
- The following is the actual Discriminator with dropout applied.
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.variable_scope)

def discriminator(x, reuse=False, alpha=0.2):
    # The sizes in the comments assume a 128x128 input image
    with tf.variable_scope("discriminator", reuse=reuse):
        x1 = tf.layers.conv2d(x, 32, 5, strides=2, padding="same")  # 64x64
        x1 = tf.maximum(alpha * x1, x1)  # leaky ReLU
        x1_drop = tf.nn.dropout(x1, 0.5)  # second argument is keep_prob

        x2 = tf.layers.conv2d(x1_drop, 64, 5, strides=2, padding="same")  # 32x32
        x2 = tf.layers.batch_normalization(x2, training=True)
        x2 = tf.maximum(alpha * x2, x2)
        x2_drop = tf.nn.dropout(x2, 0.5)

        x3 = tf.layers.conv2d(x2_drop, 128, 5, strides=2, padding="same")  # 16x16
        x3 = tf.layers.batch_normalization(x3, training=True)
        x3 = tf.maximum(alpha * x3, x3)
        x3_drop = tf.nn.dropout(x3, 0.5)

        x4 = tf.layers.conv2d(x3_drop, 256, 5, strides=2, padding="same")  # 8x8
        x4 = tf.layers.batch_normalization(x4, training=True)
        x4 = tf.maximum(alpha * x4, x4)
        x4_drop = tf.nn.dropout(x4, 0.5)

        x5 = tf.layers.conv2d(x4_drop, 512, 5, strides=2, padding="same")  # 4x4
        x5 = tf.layers.batch_normalization(x5, training=True)
        x5 = tf.maximum(alpha * x5, x5)
        x5_drop = tf.nn.dropout(x5, 0.5)

        flat = tf.reshape(x5_drop, (-1, 4*4*512))
        logits = tf.layers.dense(flat, 1)
        # Dropout here affects only `out`; the returned logits are left as-is
        logits_drop = tf.nn.dropout(logits, 0.5)
        out = tf.sigmoid(logits_drop)
        return out, logits
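What the dropout calls in the Discriminator above do can be illustrated with a small NumPy sketch. It assumes the TensorFlow 1.x convention, where the second argument is the probability of keeping a node and the kept values are scaled by 1/keep_prob. The function name `dropout` and the array sizes are made up for illustration.

```python
import numpy as np

def dropout(x, keep_prob, rng):
    """Zero each element with probability (1 - keep_prob); scale the rest by 1/keep_prob."""
    mask = rng.random(x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 100))           # dummy activations, all equal to 1
y = dropout(x, keep_prob=0.5, rng=rng)

kept_fraction = float(np.mean(y != 0))
print("fraction of nodes kept:", kept_fraction)
print("mean before / after   :", x.mean(), y.mean())
```

About half of the activations are zeroed, and the survivors are doubled so that the average activation stays roughly the same. This scaling is why no extra correction is needed when dropout is switched off at inference time.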
- A high learning rate makes training progress quickly, but it diverges easily and learning becomes unstable.
- I tried various values starting from 1e-2, and 1e-4 seemed about right. In my case, learning was too slow at 1e-5.
- For the behavior at various learning rates, this article is easy to understand.
- Initially the ratio was about 8:2, but I changed it to 6:4. I couldn't really feel the effect.
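To illustrate the learning-rate trade-off described above, here is a toy sketch (not the DCGAN training itself) that runs plain gradient descent on f(x) = x². The function `descend` and the three rates are hypothetical choices for this toy problem and are not comparable to the values I tried for the actual model.

```python
def descend(lr, steps=50, x=1.0):
    """Run `steps` gradient descent updates on f(x) = x**2 starting from x."""
    for _ in range(steps):
        grad = 2.0 * x      # derivative of x**2
        x -= lr * grad
    return x

# Distance from the minimum (x = 0) after 50 steps for each learning rate.
results = {lr: abs(descend(lr)) for lr in (1.1, 0.1, 1e-4)}

for lr, dist in results.items():
    print("lr=%g -> |x| after 50 steps = %.6g" % (lr, dist))
```

With the largest rate every step overshoots the minimum and the value blows up. The middle rate converges quickly, and the smallest rate has barely moved after 50 steps. The same pattern, at different values, is what I saw between 1e-2, 1e-4, and 1e-5.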
100epoch
200epoch
300epoch
400epoch
500epoch
- For now I ran it for 500 epochs. Seen from a distance, it does feel like sea slugs are being generated.
- But honestly, the result is so-so...
- Possible factors include: "Are there enough epochs?", "Do the images contain too much clutter (rocky backgrounds, etc.)?", "Is the network too deep?", and "Would somewhat simpler images be better after all?"
- I wanted to improve it and train a little longer, but it runs on **Google Colaboratory**, and the connection time limits make that quite hard.
- There are a few things I want to write about Colaboratory, so the next section covers it.
Colaboratory is a Jupyter notebook environment that Google provides in the cloud, and it lets you use a GPU worth about 800,000 yen. There is no need to build an environment or apply for Datalab. On top of that, it is free. It is extremely convenient, but it has the following restrictions.
- If you stay connected to a GPU for a certain amount of time in a day (recently about 4 hours, or 500 epochs in my case), you can no longer use one that day. (This comes from a shortage of GPU resources in Colaboratory, so there is no workaround and you just have to wait. GPUs are said to be assigned preferentially to users who do not use them constantly.)
- The runtime is disconnected after 90 minutes of inactivity, or after 12 hours at most, and the notebook's training results are reset as well.
- So I used **Hyperdash** to get around the 90-minute problem. It lets you keep the runtime connected for more than 90 minutes.
- Sending the training log to Hyperdash also solves another problem besides the 90-minute one: the "Buffered data was truncated after reaching the output size limit." message that makes it impossible to check the log on Colaboratory.
- Hyperdash is a smartphone app, so you can conveniently check the log even when you are out.
- Hyperdash can show plots and parameters of the training progress, but here the only goal is to prevent the runtime from disconnecting, so the steps below are enough.
#First, start the smartphone app Hyperdash and create an account.
#Install Hyperdash
!pip install hyperdash
from hyperdash import monitor_cell
!hyperdash login --email
You will be asked for your Hyperdash email address and password, so enter them. After that, all that is left is to write the code that uses Hyperdash.
# Using Hyperdash
from tensorflow.keras.callbacks import Callback
from hyperdash import Experiment

class Hyperdash(Callback):
    def __init__(self, entries, exp):
        super(Hyperdash, self).__init__()
        self.entries = entries
        self.exp = exp

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # logs may be None
        # Send each requested metric to Hyperdash if it was logged this epoch
        for entry in self.entries:
            log = logs.get(entry)
            if log is not None:
                self.exp.metric(entry, log)

exp = Experiment("Any name")
hd_callback = Hyperdash(["val_loss", "loss", "val_accuracy", "accuracy"], exp)
# ~~~ Training execution code ~~~
exp.end()
Now, if you look at the smartphone app Hyperdash, you should see the learning log.
Hyperdash solved the 90-minute problem, but the runtime may still be disconnected for some other reason, so I think it is a good idea to split training into smaller chunks and save the results as .ckpt files. The .ckpt files also disappear when the runtime is disconnected, so save them elsewhere early.
# Assumes saver = tf.train.Saver() and an open tf.Session named sess
# Save the training result as .ckpt
saver.save(sess, "/****1.ckpt")
# Load the training result saved as .ckpt and resume from there
saver.restore(sess, "/****1.ckpt")
# Download the .ckpt data file to your local machine
# (restoring it later also needs the matching .index file)
from google.colab import files
files.download("/****1.ckpt.data-00000-of-00001")
- DCGAN is difficult because the model is complex and overfitting occurs easily. The first thing to consider is building a simpler model with fewer layers.
- Although they may not be directly related to overfitting, I will also pay attention to the points mentioned above: the number of epochs, simpler images, and a simpler subject.
- The latent variable may also be a fairly important parameter. I will investigate further.
- This may have been a hard article to read because I just wrote down what I did. Thank you for reading to the end. DCGAN is fun because the results appear as images. I will keep making improvements and changes.