J'ai essayé d'exécuter GAN dans Colaboratory

introduction

Dans l'environnement Jupyter Notebook appelé Colaboratory, qui m'intéressait depuis longtemps, j'ai juste essayé d'exécuter GAN, ce qui m'intéressait depuis longtemps.

Concernant Colavoratory, l'article [Use free GPU at speed per second] Deep learning practice Tips on Colaboratory a été utile.

Pour le GAN, j'ai jeté un coup d'œil rapide à Articles Generative Adversarial Networks. GAN est une sorte de méthode qui se rapproche de la distribution de probabilité des données à portée de main (considérée comme une distribution uniforme). Lorsque les deux réseaux G et D sont bien formés, la distribution de probabilité des données générées par G est à portée de main. Il semble correspondre à la distribution de probabilité des données. Il n'est peut-être pas possible d'apprendre d'une bonne manière, c'est pourquoi nous recherchons actuellement des moyens de le faire.

Courir

Pour le code GAN, je me suis référé à ici. GAN a été simplement codé à l'aide de keras et ce fut une expérience d'apprentissage.

Deux MLP sont définis (c'est G et D), la sortie de G est donnée à D et Adam les entraîne en alternance. D apprend à faire la distinction entre «données disponibles» et «sortie de G». G s'entraîne en manipulant les données de l'enseignant de sorte que le résultat de la discrimination de D devienne «les données disponibles». Pour le moment, D n'est pas formé. Les données d'entraînement sont MNIST.

Si vous pensez qu'il existe un ReLU que vous n'êtes pas familier, il semble qu'il s'appelle Leaky ReLU, qui est souvent utilisé de nos jours. (Référence: À propos de la fonction d'activation ReLU et du clan ReLU [informations supplémentaires]) Contrairement à ReLU, même si l'entrée x de la fonction d'activation est égale ou inférieure à 0, x * α La valeur de est sortie. On dit que ce n'est pas efficace pour le wiki, mais est-il possible de réduire au maximum la disparition du gradient? Je ne suis pas sûr.

Le code a fonctionné sans aucun problème, mais il ne renvoie pas d'historique car je m'entraîne en exécutant train_on_batch dans ma propre boucle au lieu de fit. Je veux visualiser la perte et l'acc, donc je vais ajouter du code pour l'enregistrer en tant que variable d'instance et code pour la visualisation.

 # save
      self.all_d_loss_real.append(d_loss_real)
      self.all_d_loss_fake.append(d_loss_fake)
      self.all_g_loss.append(g_loss)
      
      if epoch % sample_interval == 0:
        self.sample_images(epoch)
        np.save('d_loss_real.npy', self.all_d_loss_real)
        np.save('d_loss_fake.npy', self.all_d_loss_fake)
        np.save('g_loss.npy', self.all_g_loss)

réel est la perte D des données disponibles, et fausse est la perte D des données générées par G. Code à sauvegarder localement.

from google.colab import files
import os

file_list = os.listdir("images")

for file in file_list:
    files.download("images"+os.sep+file)

files.download('d_loss_real.npy')
files.download('d_loss_fake.npy')
files.download('g_loss.npy')

C'est un code pour tracer la perte, etc.

import numpy as np
import pylab as plt


t1 = np.load('d_loss_real.npy')
t2 = np.reshape(np.load('d_loss_fake.npy'),[np.shape(t1)[0],2])
g_loss = np.load('g_loss.npy')

t = (t1+t2)/2
d_loss = t[:,0]
acc = t[:,1]
d_loss_real = t1[:,0]
d_loss_fake = t2[:,0]
acc_real = t1[:,1]
acc_fake = t2[:,1]


n_epoch = 29801

x = np.linspace(1,n_epoch,n_epoch)
plt.plot(x, acc, label='acc')
plt.plot(x, d_loss, label='d_loss')
plt.plot(x, g_loss, label='g_loss')
plt.plot(x, d_loss_real, label='d_loss_real')
plt.plot(x, d_loss_fake, label='d_loss_fake')
plt.legend()
plt.ylim([0, 2])
plt.grid()
plt.show()

#moyenne mobile
num=100#Nombre de moyennes mobiles
b=np.ones(num)/num
acc2=np.convolve(acc, b, mode='same')
d_loss2=np.convolve(d_loss, b, mode='same')
d_loss_real2=np.convolve(d_loss_real, b, mode='same')
d_loss_fake2=np.convolve(d_loss_fake, b, mode='same')
g_loss2=np.convolve(g_loss, b, mode='same')

x = np.linspace(1,n_epoch,n_epoch)
plt.plot(x, acc2, label='acc')
plt.plot(x, d_loss2, label='d_loss')
plt.plot(x, g_loss2, label='g_loss')
plt.plot(x, d_loss_real2, label='d_loss_real')
plt.plot(x, d_loss_fake2, label='d_loss_fake')
plt.legend()
plt.ylim([0,1.2])
plt.grid()
plt.show()

résultat

Image générée de G epoch=0 0.png

epoch=200 200.png

epoch=1000 1000.png

epoch=3000 3000.png

epoch=7000 6600.png

epoch=10000 9800.png

epoch=20000 20000.png

epoch=30000 29800.png

Au fur et à mesure que le nombre d'époques augmente, des images similaires à MNIST seront générées, mais il semble qu'il n'y aura pas de changement particulier à partir de l'époque 7000 environ.

Taux de réponse correct et perte t.png

Moyenne mobile dans la figure ci-dessus (n = 100, rempli de zéros aux deux extrémités) t2.png

Depuis environ 7 000, acc 0,63, d_loss (également réel et faux) 0,63, g_loss 1,02 ~ 1,08 (légère augmentation) (d_loss et g_loss sont des entropies croisées binaires). real est la perte D des données disponibles, fake est la perte D des données générées par G et d_loss est la moyenne.

la perte est définie comme la formule suivante.

\textrm{loss} = -\frac{1}{N}\sum_{n=1}^{N}\bigl( y_n\log{p_n}+(1-y_n)\log{(1-p_n)}\bigr)

N est le nombre de données, y est l'étiquette et p est la valeur de sortie de D (0,1).

C'est déroutant car il contient $ \ log $, mais ce que je fais, c'est la sortie D moyenne $ \ bigl (\ prod_ {n = 1} ^ {N} p_n ^ {y_n} \ bigr) ^ {\ J'ai changé frac {1} {N}} $ en $ \ log $, et $ \ log $ de 0 à 1 est un nombre négatif et difficile à voir, donc j'ai juste ajouté un moins pour en faire un nombre positif.

\begin{align}
\textrm{loss} &= -\frac{1}{N}\sum_{n=1}^{N}\bigl( y_n\log{p_n}+(1-y_n)\log{(1-p_n)}\bigr) \\
&= -\log{\bigl( \prod_{n=1}^{N}p_n^{y_n}\bigr)^{\frac{1}{N}}} -\log{\bigl( \prod_{n=1}^{N}(1-p_n)^{y_n}\bigr)^{\frac{1}{N}}}
\end{align}

◯ Perte vers l'époque 25000

loss Sortie D moyenne
g_loss 1.06 0.35
d_loss 0.63 0.53
d_loss_real 0.63 0.53
d_loss_fake 0.63 0.47

Considération

Plus l'étiquette correspond à la sortie, plus la perte est faible. Le GAN ne vise pas à réduire les pertes, ce n'est donc pas un problème qu'il ne diminue pas.

Si l'apprentissage se passe bien et que les données à portée de main et les données générées par G sont totalement indiscernables (le but de GAN est d'être dans cet état), acc = 0,5 devrait l'être, mais pour autant que le résultat soit vu, il en est ainsi. Ce n'est pas.

En regardant l'image générée par G, il semble que ce ne soit clairement pas un nombre manuscrit, ce qui est probablement la raison pour laquelle acc est élevé. C'est peut-être un peu mieux si vous jouez avec les paramètres, mais comme le but n'est pas de conduire, je vais m'arrêter ici pour le moment.

La valeur de g_loss signifie que plus la valeur de g_loss est faible, plus D détermine que l'image générée par G est vraie, c'est-à-dire plus D est trompé. Inversement, plus la valeur de g_loss est élevée, plus D n'est pas trompé. Si l'objectif est d'avoir une sortie D moyenne de g_loss de 0,5, alors g_loss est de 0,7, donc j'aimerais qu'il baisse un peu plus.

Je ne pense pas qu'il arrive que acc corresponde à d_loss.

À partir de l'époque 7000 ~, il est inquiétant que la quantité de diminution de d_loss_fake soit inférieure à la quantité d'augmentation de g_loss. Même avec la sortie D moyenne, il y a une différence d'environ 10 fois. Puisque l'ordre est D learning → G learning, est-ce efficace pour Moro?

À la fin

J'avais l'impression d'être capable de le faire. Je ne pense pas qu'il y ait quoi que ce soit de particulièrement obstrué, mais comme le collaboratif n'est pas très stable, si vous pensez que le calcul plante au milieu ou que l'écran est rechargé, le cahier précédent sera affiché pour une raison quelconque, et vous ne le remarquerez pas. Je l'ai écrasé et réécrit le code en sanglotant.

2.png

Faites attention si ce pop-up apparaît par le bas après avoir rechargé l'écran. Si vous regardez attentivement le code, il s'agit du code non édité immédiatement après l'ouverture de Colaboratory, et lorsque vous l'enregistrez, il est écrasé par le code modifié.

En guise de contre-mesure, je pense que vous devriez recharger la page. Mon navigateur est Safari, mais lorsque j'appuie sur Ctrl-r pour recharger la page, le code édité est affiché et les variables après exécution sont également conservées. Si ce pop-up apparaît, je pense qu'il est plus sûr de ne pas se précipiter pour l'écraser.

Je pense que vous devez faire des sauvegardes régulières des plantages de calcul.

Recommended Posts

J'ai essayé d'exécuter GAN dans Colaboratory
J'ai essayé d'exécuter pymc
J'ai essayé d'exécuter TensorFlow
J'ai essayé d'exécuter TensorFlow dans l'environnement AWS Lambda: Préparation
J'ai essayé Grumpy (allez exécuter Python).
J'ai essayé d'implémenter Realness GAN
J'ai essayé d'exécuter prolog avec python 3.8.2.
J'ai essayé la notification de ligne en Python
J'ai essayé d'implémenter PLSA en Python
J'ai essayé d'implémenter la permutation en Python
J'ai essayé d'implémenter PLSA dans Python 2
J'ai essayé d'utiliser l'optimisation bayésienne de Python
J'ai essayé de mettre virtualenv dans l'environnement Cygwin
J'ai essayé d'implémenter ADALINE en Python
J'ai essayé d'implémenter PPO en Python
J'ai essayé de gratter
J'ai essayé PyQ
J'ai essayé AutoKeras
J'ai essayé le moulin à papier
J'ai essayé django-slack
J'ai essayé Django
J'ai essayé spleeter
J'ai essayé cgo
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé de jouer à un jeu de frappe avec Python
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé de simuler "Birthday Paradox" avec Python
J'ai essayé la méthode des moindres carrés en Python
J'ai essayé d'exécuter YOLO v3 avec Google Colab
J'ai essayé d'implémenter TOPIC MODEL en Python
J'ai essayé la sortie de caractères "*" de Python dans une autre langue
J'ai essayé d'exécuter faiss avec python, Go, Rust
J'ai essayé le comportement d'E / S Eventlet non bloquant en Python
J'ai essayé d'exécuter python -m summpy.server -h 127.0.0.1 -p 8080
J'ai essayé d'ajouter un module Python 3 en C
J'ai essayé d'exécuter Deep Floor Plan avec Python 3.6.10.
J'ai essayé d'exécuter alembic, un outil de migration pour Python
J'ai essayé d'implémenter le tri sélectif en python
J'ai essayé d'utiliser paramétré
Dessine un graphique avec Julia ... j'ai essayé une petite analyse
J'ai essayé d'utiliser argparse
J'ai essayé de représenter graphiquement les packages installés en Python
J'ai essayé d'utiliser la mimesis
J'ai essayé d'utiliser anytree
J'ai essayé le spoofing ARP
J'ai essayé d'utiliser google test et CMake en C
J'ai essayé d'exécuter l'application sur la plateforme IoT "Rimotte"
J'ai essayé d'utiliser aiomysql
J'ai essayé d'utiliser TradeWave (commerce du système BitCoin en Python)
J'ai essayé d'utiliser Summpy
J'ai essayé Python> autopep8
J'ai essayé d'exécuter python à partir d'un fichier chauve-souris
J'ai essayé d'utiliser coturn
J'ai essayé d'utiliser Pipenv
J'ai essayé d'utiliser matplotlib
J'ai essayé d'utiliser "Anvil".
J'ai essayé d'utiliser Hubot
J'ai essayé d'implémenter le poker de Drakue en Python
J'ai essayé PyCaret2.0 (pycaret-nightly)