Enregistrez la sortie du GAN conditionnel pour chaque classe ~ Avec l'implémentation cGAN par PyTorch ~

Ceci est le deuxième article. Dernière fois a implémenté DCGAN avec PyTorch et a rendu possible la sauvegarde des images de sortie une par une.

Cette fois, nous allons implémenter un GAN conditionnel amélioré (GAN conditionnel) afin que la sortie du GAN puisse être contrôlée. Dans le même temps, comme la dernière fois, nous pourrons enregistrer les images de sortie une par une.

Objectif

Implémentez le GAN conditionnel et enregistrez la sortie une par une

conditional GAN Le GAN conditionnel vous permet de séparer explicitement les images générées. Cela a été rendu possible grâce à la formation utilisant les informations d'étiquette des données des enseignants pendant la formation. Le document est ici

Extrait du papier suivant 180B6B55-C45F-40F9-8863-D5A7B5E1D19D.png C'est comme apprendre en ajoutant des informations d'étiquette de classe aux entrées Generator et Discriminator. Il semble que le format d'entrée change un peu, mais la structure de base du GAN ne change pas.

la mise en oeuvre

Passons à la mise en œuvre. Cette fois, nous allons implémenter le GAN conditionnel basé sur DCGAN qui a été implémenté la dernière fois.

Environnement d'exécution

Google Colaboratory

Importation du module et enregistrement des paramètres de destination

D'abord depuis l'importation du module

import argparse
import os
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets

import torch.nn as nn
import torch.nn.functional as F
import torch

img_save_path = 'images-C_dcgan'
os.makedirs(img_save_path, exist_ok=True)

Ligne de commande et paramètre de valeur par défaut

C'est presque la même chose que la dernière fois. Le changement subtil est que la taille de l'image générée est de 32x32 au lieu de la valeur par défaut MNIST de 28x28.

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--beta1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--beta2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
args = parser.parse_args()
#arguments pour google colab=parser.parse_args(args=[])
print(args)

C,H,W = args.channels, args.img_size, args.img_size

Réglage du poids

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal(m.weight, 1.0, 0.02)
        torch.nn.init.constant(m.bias, 0.0)

Generator Définissons le générateur. Générer avec chat Combinez les informations d'image et les informations d'étiquette à générer.

class Generator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(Generator, self).__init__()
        self.deconv1_1 = nn.ConvTranspose2d(100, d*2, 4, 1, 0)
        self.deconv1_1_bn = nn.BatchNorm2d(d*2)
        self.deconv1_2 = nn.ConvTranspose2d(10, d*2, 4, 1, 0)
        self.deconv1_2_bn = nn.BatchNorm2d(d*2)
        self.deconv2 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d*2)
        self.deconv3 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, C, 4, 2, 1)


    # forward method
    def forward(self, input, label):
        x = F.relu(self.deconv1_1_bn(self.deconv1_1(input)))
        y = F.relu(self.deconv1_2_bn(self.deconv1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        x = torch.tanh(self.deconv4(x))
        return x

La dernière fois, j'ai implémenté Generator avec Upsampling + Conv2d. Cette fois, nous implémentons en utilisant ConvTranspose2d au lieu de la méthode précédente. Cette différence est résumée dans cet article, alors jetez un œil si vous êtes intéressé.

Discriminator La définition de discriminateur. Les informations sur l'étiquette sont également jointes avec le chat ici.


class Discriminator(nn.Module):
    # initializers
    def __init__(self, d=128):
        super(Discriminator, self).__init__()
        self.conv1_1 = nn.Conv2d(C, d//2, 4, 2, 1)
        self.conv1_2 = nn.Conv2d(10, d//2, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d * 4, 1, 4, 1, 0)

    def forward(self, input, label):
        x = F.leaky_relu(self.conv1_1(input), 0.2)
        y = F.leaky_relu(self.conv1_2(label), 0.2)
        x = torch.cat([x, y], 1)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.sigmoid(self.conv4(x))
        return x

Fonction de perte et paramètres réseau

Définissez la fonction de perte, initialisez le poids, initialisez le générateur / discriminateur et réglez l'optimiseur.


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize Generator and discriminator
generator = Generator()
discriminator = Discriminator()

if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

Créer un chargeur de données

Nous allons créer un Dataloader. Cette fois, l'image est générée avec une taille de 32 * 32, de sorte que l'image MNIST est redimensionnée dans la partie de prétraitement de l'image.


# Configure data loader
os.makedirs('./data', exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Resize(args.img_size),
                       transforms.ToTensor(),
                       transforms.Normalize([0.5,], [0.5,])
                   ])),
    batch_size=args.batch_size, shuffle=True, drop_last=True)
print('the data is ok')

Training Formation du GAN.


for epoch in range(1, args.n_epochs+1):
    for i, (imgs, labels) in enumerate(dataloader):

        Batch_Size = args.batch_size
        N_Class = args.n_classes
        img_size = args.img_size
        # Adversarial ground truths
        valid = torch.ones(Batch_Size).cuda()
        fake = torch.zeros(Batch_Size).cuda()

        # Configure input
        real_imgs = imgs.type(torch.FloatTensor).cuda()

        real_y = torch.zeros(Batch_Size, N_Class)
        real_y = real_y.scatter_(1, labels.view(Batch_Size, 1), 1).view(Batch_Size, N_Class, 1, 1).contiguous()
        real_y = real_y.expand(-1, -1, img_size, img_size).cuda()

        # Sample noise and labels as generator input
        noise = torch.randn((Batch_Size, args.latent_dim,1,1)).cuda()
        gen_labels = (torch.rand(Batch_Size, 1) * N_Class).type(torch.LongTensor)
        gen_y = torch.zeros(Batch_Size, N_Class)
        gen_y = gen_y.scatter_(1, gen_labels.view(Batch_Size, 1), 1).view(Batch_Size, N_Class,1,1).cuda()
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Loss for real images
        d_real_loss = adversarial_loss(discriminator(real_imgs, real_y).squeeze(), valid)
        # Loss for fake images
        gen_imgs = generator(noise, gen_y)
        gen_y_for_D = gen_y.view(Batch_Size, N_Class, 1, 1).contiguous().expand(-1, -1, img_size, img_size)

        d_fake_loss = adversarial_loss(discriminator(gen_imgs.detach(),gen_y_for_D).squeeze(), fake)
        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss)
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        g_loss = adversarial_loss(discriminator(gen_imgs,gen_y_for_D).squeeze(), valid)
        g_loss.backward()
        optimizer_G.step()


        print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, args.n_epochs, i, len(dataloader),
                                                            d_loss.data.cpu(), g_loss.data.cpu()))

        batches_done = epoch * len(dataloader) + i
        if epoch % 20 == 0:
            noise = torch.FloatTensor(np.random.normal(0, 1, (N_Class**2, args.latent_dim,1,1))).cuda()
            #fixed labels
            y_ = torch.LongTensor(np.array([num for num in range(N_Class)])).view(N_Class,1).expand(-1,N_Class).contiguous()
            y_fixed = torch.zeros(N_Class**2, N_Class)
            y_fixed = y_fixed.scatter_(1,y_.view(N_Class**2,1),1).view(N_Class**2, N_Class,1,1).cuda()

            with torch.no_grad():
                gen_imgs = generator(noise, y_fixed).view(-1,C,H,W)

            save_image(gen_imgs.data, img_save_path + '/epoch:%d.png' % epoch, nrow=N_Class, normalize=True) 

Résultat d'exécution

Le résultat de l'exécution est le suivant. 20-19600.png Vous pouvez voir que les images générées sont soigneusement organisées pour chaque classe. Le GAN conditionnel vous permet de contrôler les images générées de cette manière.

Générez et enregistrez des images pour chaque classe

Comme la dernière fois, nous pourrons enregistrer les images une par une.


if epoch % 20 == 0:
    noise = torch.FloatTensor(np.random.normal(0, 1, (N_Class**2, args.latent_dim,1,1))).cuda()
    #fixed labels
    y_ = torch.LongTensor(np.array([num for num in range(N_Class)])).view(N_Class,1).expand(-1,N_Class).contiguous()
    y_fixed = torch.zeros(N_Class**2, N_Class)
    y_fixed = y_fixed.scatter_(1,y_.view(N_Class**2,1),1).view(N_Class**2, N_Class,1,1).cuda()

    with torch.no_grad():
        gen_imgs = generator(noise, y_fixed).view(-1,C,H,W)

    save_image(gen_imgs.data, img_save_path + '/epoch:%d.png' % epoch, nrow=N_Class, normalize=True)

Ici partie


if epoch % 20 == 0:
    for l in range(10): #Conservez 10 feuilles pour chaque classe
        noise = torch.FloatTensor(np.random.normal(0, 1, (N_Class**2, args.latent_dim,1,1))).cuda()
        #fixed labels
        y_ = torch.LongTensor(np.array([num for num in range(N_Class)])).view(N_Class,1).expand(-1,N_Class).contiguous()
        y_fixed = torch.zeros(N_Class**2, N_Class)
        y_fixed = y_fixed.scatter_(1,y_.view(N_Class**2,1),1).view(N_Class**2, N_Class,1,1).cuda()

        for m in range()
            with torch.no_grad():
                gen_imgs = generator(noise, y_fixed).view(-1,C,H,W)

            save_gen_imgs = gen_imgs[10*i]
            save_image(save_gen_imgs, img_save_path + '/epochs:%d/%d/epoch:%d-%d_%d.png' % (epoch, i, epoch,i, j), normalize=True)

Changez-le comme ça. Si vous souhaitez faire cela, vous devez modifier la structure des répertoires pour enregistrer les images.

images-C_dcgan
├── epochs:20
│   ├── 0
│   ├── 1
│   ├── 2
│   ├── 3
│   ├── 4
│   ├── 5
│   ├── 6
│   ├── 7
│   ├── 8
│   └── 9
│     .
│     .
│     .
│
└── epochs:200
    ├── 0
    ├── 1
    ├── 2
    ├── 3
    ├── 4
    ├── 5
    ├── 6
    ├── 7
    ├── 8
    └── 9

Il y a de 0 à 9 répertoires pour 20 époques. Il est plus facile de créer à la fois en utilisant ʻos.makedirs`. Les images sont maintenant enregistrées pour chaque classe.

Sommaire

Cette fois, nous avons implémenté le GAN conditionnel suivant DCGAN afin que les images générées puissent être sauvegardées une par une. Cette fois, nous avons implémenté le GAN conditionnel, qui est le plus simple, en ajoutant des informations d'étiquette aux entrées Générateur et Discriminateur. Actuellement, les normes de facto pour la mise en œuvre du GAN conditionnel sont des technologies telles que le discriminateur de projection et la normalisation conditionnelle des lots. Je ne comprends pas grand-chose à la technologie ici, alors si j'ai une chance, j'aimerais étudier tout en la mettant en œuvre.

Recommended Posts

Enregistrez la sortie du GAN conditionnel pour chaque classe ~ Avec l'implémentation cGAN par PyTorch ~
Sauvegardez la sortie de GAN une par une ~ Avec l'implémentation de GAN par PyTorch ~
Ajouter des attributs d'objets de classe avec une instruction for
Génération du caractère indésirable MNIST (KMNIST) avec cGAN (GAN conditionnel)
Résumé de l'implémentation de base par PyTorch
Sortie csv avec un nombre différent de chiffres pour chaque colonne avec numpy
La troisième nuit de la boucle avec pour
La deuxième nuit de la boucle avec pour
Construisez un serveur API pour vérifier le fonctionnement de l'implémentation frontale avec python3 et Flask
Sortie de la table spécifiée de la base de données Oracle en Python vers Excel pour chaque fichier
Afficher progressivement la sortie de la commande exécutée par le sous-processus.
Créez Fatjar en changeant la classe principale avec Gradle
Calculez la valeur totale de plusieurs colonnes avec awk
Implémentation python de la classe de régression linéaire bayésienne
Préparation de l'environnement d'exécution de PyTorch avec Docker Novembre 2019
Attribuer une date au nom du PDF décomposé pour chaque page
[Pour les débutants] Quantifier la similitude des phrases avec TF-IDF