Je voulais faire évoluer cGAN vers ACGAN

introduction

Cet article est une suite de «[CGAN (GAN conditionnel) Génère MNIST (KMNIST)» (https://qiita.com/kyamada101/items/5195b1b32c60168b1b2f). C'est un record quand on essaie de faire ACGAN basé sur cGAN.

Je pense que c'est une idée naturelle en termes d'évolution depuis cGAN, mais quand je la mets en œuvre, c'est assez ...

Qu'est-ce que ACGAN

Depuis que j'ai brièvement présenté cGAN dans l'article précédent, je vais expliquer brièvement ACGAN. ACGAN est, en un mot, ** "cGAN où le discriminateur effectue également des tâches de classification" **. C'est une méthode qui permet la sortie d'images avec plus de variations.

L'article original est [ici](Synthèse d'image conditionnelle avec les GAN de classificateur auxiliaire)

A. Odena, C. Olah, J. Shlens. Conditional Image Synthesis With Auxiliary Classifier GANs. CVPR, 2016

En ce qui concerne les articles d'ACGAN, certaines personnes ont publié les articles originaux, ce sera donc utile.
Article de référence
Explication des articles sur AC-GAN (synthèse d'image conditionnelle avec les GAN de classificateur auxiliaire)

Structure du modèle ACGAN

Dans cGAN, l'image authentique / fausse et les informations d'étiquette étaient entrées dans Discriminator, et l'identification de l'authentique ou du faux était sortie. D'un autre côté, dans ACGAN, l'entrée de Discriminator n'est qu'une image, et pas seulement l'identification du vrai ou du faux mais aussi le jugement de classe pour deviner quelle classe il est ajouté à la sortie. Il ressemble à ce qui suit lorsqu'il est écrit dans un diagramme. ACGAN.jpg La partie «classe» de la figure est la sortie de la classification prédite par Discriminator. Comme «label», il se présente sous la forme d'un vecteur de dimension de numéro de classe.

Mis en œuvre pour le moment

ACGAN a l'implémentation PyTorch sur GitHub. Avec cela comme référence, modifions l'implémentation de cGAN que j'ai écrit dans l'article précédent.

Que faire

Est presque tout. Ensuite, la structure de Discriminator ressemble à ceci. ACGAN_discriminator.jpg Il s'agit d'un dessin du diagramme de structure du discriminateur cGAN publié dans l'article précédent, mais la partie représentée en rouge est le changement d'ACGAN.

Une implémentation de Discriminator.

python



class Discriminator(nn.Module):
    def __init__(self, num_class):
        super(Discriminator, self).__init__()
        self.num_class = num_class
        
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), #L'entrée est 1 canal(Parce que c'est noir et blanc),Nombre de filtres 64,Taille de filtre 4*4
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(128),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        
        self.fc_TF = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )
        
        self.fc_class = nn.Sequential(
            nn.Linear(1024, num_class),
            nn.LogSoftmax(dim=1),
        )
        
        self.init_weights()
        
    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
        
    def forward(self, img):
        x = self.conv(img)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)
        x_TF = self.fc_TF(x)
        x_class = self.fc_class(x)
        return x_TF, x_class

Il semble y avoir différentes manières d'ajouter le résultat de la classification. Dans l'implémentation PyTorch du lien que j'ai posté plus tôt, la couche Linear était bifurquée à la fin, donc je l'implémente de la même manière ici.

Selon ce changement, la fonction par époque ressemble à ceci.

python



def train_func(D_model, G_model, batch_size, z_dim, num_class, TF_criterion, class_criterion,
               D_optimizer, G_optimizer, data_loader, device):
    #Mode entraînement
    D_model.train()
    G_model.train()

    #La vraie étiquette est 1
    y_real = torch.ones((batch_size, 1)).to(device)
    D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Etiquette de bruit à mettre en D

    #La fausse étiquette est 0
    y_fake = torch.zeros((batch_size, 1)).to(device)
    D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Etiquette de bruit à mettre en D
    
    #Initialisation de la perte
    D_running_TF_loss = 0
    G_running_TF_loss = 0
    D_running_class_loss = 0
    D_running_real_class_loss = 0
    D_running_fake_class_loss = 0
    G_running_class_loss = 0
    
    #Calcul lot par lot
    for batch_idx, (data, labels) in enumerate(data_loader):
        #Ignorer si inférieur à la taille du lot
        if data.size()[0] != batch_size:
            break
        
        #Création de bruit
        z = torch.normal(mean = 0.5, std = 1, size = (batch_size, z_dim)) #Moyenne 0.Générer des nombres aléatoires selon une distribution normale de 5
        
        real_img, label, z = data.to(device), labels.to(device), z.to(device)
        
        #Mise à jour du discriminateur
        D_optimizer.zero_grad()
        
        #Mettre une image réelle dans Discriminator et propager vers l'avant ⇒ Calcul des pertes
        D_real_TF, D_real_class = D_model(real_img)
        D_real_TF_loss = TF_criterion(D_real_TF, D_y_real)
        CEE_label = torch.max(label, 1)[1].to(device)
        D_real_class_loss = class_criterion(D_real_class, CEE_label)
        
        #Mettre l'image créée en mettant du bruit dans Generator dans Discriminator et propager vers l'avant ⇒ Calcul de la perte
        fake_img = G_model(z, label)
        D_fake_TF, D_fake_class = D_model(fake_img.detach()) #fake_Stop Loss calculé dans les images pour qu'il ne se propage pas vers Generator
        D_fake_TF_loss = TF_criterion(D_fake_TF, D_y_fake)
        D_fake_class_loss = class_criterion(D_fake_class, CEE_label)

        #Minimiser la somme de deux pertes
        D_TF_loss = D_real_TF_loss + D_fake_TF_loss
        D_class_loss = D_real_class_loss + D_fake_class_loss
        
        D_TF_loss.backward(retain_graph=True)
        D_class_loss.backward()
        D_optimizer.step()
        
        D_running_TF_loss += D_TF_loss.item()
        D_running_class_loss += D_class_loss.item()
        D_running_real_class_loss += D_real_class_loss.item()
        D_running_fake_class_loss += D_fake_class_loss.item()


        #Mise à jour du générateur
        G_optimizer.zero_grad()
        
        #L'image créée en mettant du bruit dans le Générateur est placée dans le Discriminateur et propagée vers l'avant ⇒ La partie détectée devient Perte
        fake_img_2 = G_model(z, label)
        D_fake_TF_2, D_fake_class_2 = D_model(fake_img_2)
        
        #G perte(max(log D)Optimisé avec)
        G_TF_loss = -TF_criterion(D_fake_TF_2, y_fake)
        G_class_loss = class_criterion(D_fake_class_2, CEE_label) #Du point de vue de G, ce serait bien s'il pensait que D était réel et lui donnait un cours.
        
        G_TF_loss.backward(retain_graph=True)
        G_class_loss.backward()
        G_optimizer.step()
        G_running_TF_loss += G_TF_loss.item()
        G_running_class_loss -= G_class_loss.item()
        
    D_running_TF_loss /= len(data_loader)
    D_running_class_loss /= len(data_loader)
    D_running_real_class_loss /= len(data_loader)
    D_running_fake_class_loss /= len(data_loader)

    G_running_TF_loss /= len(data_loader)
    G_running_class_loss /= len(data_loader)
    
    return D_running_TF_loss, G_running_TF_loss, D_running_class_loss, G_running_class_loss, D_running_real_class_loss, D_running_fake_class_loss

En plus des changements mentionnés précédemment, j'ai également modifié le bruit à ajouter. La dernière fois, il s'agissait d'une distribution normale avec 30 dimensions, moyenne 0,5 et écart type 0,2, mais cette fois, il s'agit d'une distribution normale avec 100 dimensions, moyenne 0,5 et écart type 1.

La perte de classification est torch.nn.NLLLoss (). Cela correspondait également à la mise en œuvre du lien plus tôt.

résultat

Le premier est le graphique des pertes. Dans ACGAN, il existe deux types de perte, la perte d'identification réelle ou fausse et la perte de classification, et les deux pertes sont propagées à la fois au générateur et au discriminateur. Il est également tracé séparément dans le graphique. simpleACGAN loss.png

«T / F_loss» est la perte (ligne continue) pour une identification authentique / fausse, et «class_loss» est la perte (ligne pointillée) pour la classification.

En regardant cela, il semble que cela fonctionne. Pourtant... result_ACGAN (1).gif Ceci est un gif lorsqu'une image de chaque étiquette est générée pour chaque époque. J'ai entré les informations d'étiquette de sorte que la ligne supérieure soit "Ah, I, U ..." à partir de la gauche, et la partie inférieure droite est "..., N, ゝ". Il n'y a pratiquement aucune correspondance entre l'étiquette et l'image générée. Mais on dirait qu'il produit "du texte sur une autre étiquette" plutôt qu'une image complètement dénuée de sens.

Semblable à cGAN, j'ai essayé de générer 5 "A" à "ゝ" par Generator après un entraînement de 100 points. many_.png N'est-ce pas seulement "ke" qui semble correspondre à Label-chan? (Au contraire, le mode s'effondre complètement ...)

D'ailleurs, c'est le résultat de la génération de cGAN après 100 périodes d'entraînement dans les mêmes conditions. epoch_00100.png De toute évidence, cGAN génère des caractères plus proches de l'étiquette.

Pourquoi ça ne marche pas ...?

À première vue sur la sortie, chez ACGAN, Generator et Discriminator ** pensent que les caractères avec une forme différente sont les caractères de cette étiquette ** (Ex: Discriminator et Generator sont tous les deux "I" N'est-ce pas (celui qui ressemble à la forme est traité comme l'étiquette «A»)? J'ai pensé.

simpleACGAN_discriminatorloss.png Il s'agit d'un graphique qui divise la perte de la classification du discriminateur en la perte dérivée de l'image réelle et la perte dérivée de la fausse image (= image créée par le générateur). sum_class_loss est la valeur totale (= identique à la ligne pointillée rouge dans le graphique précédent). En regardant ce graphique, Discriminator commet une erreur en jugeant l'image réelle (en particulier dans les premiers stades de l'apprentissage) et devine le jugement de la fausse image. (En termes numériques, real_class_loss est environ 20 fois la valeur au début de fake_class_loss et environ 5 fois à la fin)

En d'autres termes, ** l'image créée par le Générateur avec l'étiquette «A» est traitée comme «A» par Discriminator même si la forme réelle est assez différente de «A» **. Je peux imaginer ça.

Idéalement peut-être, la perte de la classification devrait être à peu près la même pour les images réelles et fausses.

Comparez avec ce qui semble fonctionner

Comme mentionné dans l'article original d'ACGAN, il semble que ** s'il y a trop de classes, la qualité de l'image de sortie se détériorera sur le même réseau **. Dans l'article original, ImageNet (1000 classes) est divisé en 10 classes x 100 cas pour l'expérimentation.

Par conséquent, j'ai décidé d'essayer ceci dans 5 classes une fois. Faisons la même structure de réseau et essayons de générer 5 caractères de "A" à "O". ACGAN_5class.png Le graphique des pertes est similaire. Il semble qu'il y ait encore de la place pour que le T / F_loss diminue. result_tmp.gif Il y a aussi des inégalités ici, mais la seconde moitié est assez belle. Ensuite, générons 5 images chacune après une formation de 100 époques. many_ (1).png Il semble que le mode ne s'effondre pas.

Ensuite, c'est la perte de la classification de Discriminateur. ACGAN_Discriminator_5class.png Sur une base numérique, il y avait une différence d'environ 10 fois au début, mais c'est presque la même valeur dans l'étape finale, mais il est difficile de voir ce graphique, donc je ne l'afficherai qu'après 3 époques. ACGAN_Discriminator_5class_3epoch~.png Si vous regardez ceci, vous pouvez voir que real_class_loss et fake_class_loss deviennent des valeurs assez proches.

perte par iter

En premier lieu, y a-t-il une différence de 10 à 20 fois entre la véritable classification et la fausse classification de la 1ère époque dans les premiers stades de l'apprentissage? ?? J'ai pensé, alors j'ai essayé d'afficher la perte pour chaque iter (pour chaque mini-lot). ACGAN_discriminator_loss_per_iter.png Il est vrai que la valeur de la perte ne change pas entre «real_class_loss» et «fake_class_loss» au début, mais vous pouvez voir que «fake_class_loss» baisse fortement.

Essayez le pré-apprentissage

J'ai essayé de n'entraîner que l'image réelle dans les premières époques, mais cela n'avait pas beaucoup de sens, alors j'ai décidé de ne pré-apprendre que la tâche de classification.

Obtenez uniquement le discriminateur et ne résolvez que la tâche de classification.

Résultats de la tâche de classification

ACGAN_discriminate_loss.png ACGAN_discriminating_acc.png La convergence est assez rapide, donc je ne fais que 20 époques. En conséquence, il est subtil, mais pour le moment, j'utiliserai ce Discriminateur après une formation de 20 époques.

Résultats lors de l'application du pré-apprentissage

pretrianed_ACGAN.png La perte vraie / fausse est presque la même que sans pré-apprentissage. Perte de classification Quant à elle, elle est devenue assez petite depuis le début.

Maintenant, regardons la perte de classification dérivée de l'image réelle et de la fausse image. preteained_ACGAN_real_fake.png J'ai essayé d'apprendre jusqu'à 300 époques. Par rapport au non pré-appris, la valeur de perte dérivée de l'image réelle est également considérablement inférieure. C'est environ quatre fois plus que la perte dérivée de la fausse image, mais ce n'est toujours pas la même valeur.

Jetons un coup d'oeil à l'image générée par ACGAN après cette formation de 300 époques. pretrained_ACGAN_300epoch.png Hmm. .. Aucun effet n'est observé. Il n'y a pas d'augmentation du nombre de caractères réussis et une réduction du mode se produit.

Impressions

Il existe plusieurs ensembles de données kuzuji qui ont un grand nombre de données par caractère, 6000 et seulement environ 300 à 400. Je pense que plus le nombre de données par classe est grand, mieux c'est, alors j'ai pensé que cela pourrait fonctionner si le nombre de données était supérieur à CIFAR-10, mais ce n'était pas bon.

Personnellement, la distance entre les caractères de chaque étiquette dans l'espace latent n'est-elle pas proche (= les caractères avec des étiquettes différentes sont assez proches dans l'espace latent)? Je pense. Dans l'expérience de l'article original, j'ai expérimenté CIFAR-10 et ImageNet toutes les 10 classes, mais dans le cas des caractères indésirables, il n'y avait qu'un peu plus de la moitié des personnages qui travaillaient dans 10 classes, et cela ne fonctionnait que dans 5 classes.

En tout cas, il semble assez difficile de viser et de sortir la classe 49 avec ACGAN, donc je vais abandonner ...

Recommended Posts

Je voulais faire évoluer cGAN vers ACGAN
Chaîne de hachage que je voulais éviter (2)
Chaîne de hachage que je voulais éviter (1)
Je voulais résoudre ABC160 avec Python
Je voulais résoudre ABC159 avec Python
Je voulais résoudre ABC172 avec Python
Je voulais vraiment copier avec du sélénium
Implémentation de DQN avec TensorFlow (je voulais ...)
Scraping de pages i-Town: je voulais prendre la place de Wise-kun
Je voulais jouer avec la courbe de Bézier
Je voulais installer Python 3.4.3 avec Homebrew + pyenv
Je voulais juste comprendre le module Pickle de Python
Je voulais aussi vérifier les indices de type avec numpy
Je voulais utiliser la bibliothèque Python de MATLAB
J'ai essayé de déboguer.
[Échec] Je voulais générer des phrases en utilisant TextRegressor de Flair
Une histoire sur la volonté de modifier un peu le site d'administration de Django
Je voulais résoudre le concours de programmation Panasonic 2020 avec Python
Je voulais ignorer certaines extensions lors de la création de la documentation Sphinx
Je voulais m'inquiéter du temps d'exécution et de l'utilisation de la mémoire
J'ai capturé le projet Toho avec Deep Learning ... je le voulais.
Je voulais calculer un tableau avec la méthode des subs de Sympy
Je voulais supprimer plusieurs objets en s3 avec boto3
J'ai essayé d'apprendre PredNet
J'ai essayé d'organiser SVM.
J'ai parlé à Raspberry Pi
J'ai essayé d'implémenter PCANet
Introduction à l'optimisation non linéaire (I)
J'ai appliqué LightFM à Movielens
Je veux résoudre SUDOKU
J'ai essayé de réintroduire Linux
J'ai essayé de présenter Pylint
J'ai essayé de résumer SparseMatrix
jupyter je l'ai touché
J'ai essayé d'implémenter StarGAN (1)
Je voulais le faire comme exécuter un cas de test pour AtCoder.
Je voulais créer une présentation intelligente avec Jupyter Notebook + nb present
Je voulais convertir ma photo de visage en un style Yuyu.
Je voulais contester la classification du CIFAR-10 en utilisant l'entraîneur de Chainer
Je voulais faire quelque chose comme la pipe d'Elixir en Python
Je voulais résoudre le problème ABC164 A ~ D avec Python
Ce que j'ai fait quand je voulais rendre Python plus rapide -Édition Numba-
[Django] Je voulais tester lors du POST d'un fichier volumineux [TDD]