Génération du caractère indésirable MNIST (KMNIST) avec cGAN (GAN conditionnel)

introduction

J'ai essayé de générer un caractère indésirable MNIST en utilisant une sorte de GAN, cGAN (GAN conditionnel). Pour des aspects théoriques détaillés, veuillez vous référer aux liens qui seront utiles le cas échéant.

--Je veux implémenter GAN avec PyTorch ――Je souhaite créer un modèle capable de générer l'image souhaitée

J'espère que cela sera utile pour ceux qui aiment.

Qu'est-ce que Kuzuji MNIST (KMNIST)?

À propos de l'ensemble de données utilisé cette fois. KMNIST est un ensemble de données créé pour l'apprentissage automatique en tant que dérivé du "Japanese Classics Kuzuji Data Set" créé par le Humanities Open Data Sharing Center. Vous pouvez le télécharger depuis GitHub Link. kmnist.png

"KMNIST Dataset" (créé par CODH) "Japanese Classics Kuzuji Dataset" (Kokubunken et al.) Adapté doi: 10.20676 / 00000341

Comme ce MNIST (numéro manuscrit) que connaît tous ceux qui ont fait du machine learning, une image mesure 1 x 28px x 28px.

Les trois types d'ensembles de données suivants peuvent être téléchargés à partir du référentiel au format compressé numpy.array.

--kuzushiji-MNIST (10 caractères de hiragana) --kuzushiji-49 (49 caractères de hiragana) --kuzushiji-kanji (3832 kanji)

Parmi ceux-ci, "kuzushiji-49" sera utilisé cette fois. Il n'y a pas de raison profonde particulière, mais si 49 caractères hiragana peuvent être ciblés et générés, est-il possible de générer des phrases manuscrites? Je pensais que c'était une motivation légère.

Qu'est-ce que le GAN

Abordons brièvement GAN avant cGAN. GAN est une abréviation de "Generative Adversarial Network" (= réseau de génération hostile) et est une sorte de modèle de génération d'apprentissage en profondeur. C'est particulièrement efficace dans le domaine de la génération d'images, et je pense que le résultat de la génération d'images faciales de personnes qui n'existent pas dans le monde est célèbre.

Structure du modèle GAN

Ce qui suit est un schéma de modèle approximatif du GAN. "G" signifie Generator et "D" signifie Discriminator.

Generator générera une fausse image aussi proche de la réalité que possible à partir du bruit. Le discriminateur fait la distinction entre l'image réelle (real_img) extraite de l'ensemble de données et la fausse image (fake_img) créée par le générateur (vrai ou faux).

En répétant cet apprentissage, le générateur tente de créer une image aussi proche que possible de la chose réelle que le discriminateur ne peut pas détecter, et le discriminateur essaie de détecter le faux créé par le générateur et la chose réelle dérivée de l'ensemble de données, de sorte que le générateur est généré. La précision augmentera.

GAN.jpg


Article de référence
GAN (1) Comprendre la structure de base que je n'entends plus

Les articles liés au GAN sont organisés dans This GitHub Repository.

Qu'est-ce que cGAN (GAN conditionnel)?

Ensuite, je voudrais parler du GAN conditionnel utilisé cette fois. En termes simples, c'est ** "GAN qui peut générer l'image souhaitée" **. L'idée est simple, c'est comme décider de l'image à générer en ajoutant des informations d'étiquette à l'entrée du discriminateur et du générateur.

Le document original est ici

structure du modèle de cGAN

C'est la même chose qu'un GAN normal sauf que ** "Entrez le libellé pendant l'entraînement" **. De plus, bien que les informations d'étiquette soient utilisées, Discriminator détermine uniquement "si l'image est authentique". cGAN.jpg


Article de référence
Implémentation de GAN (6) GAN conditionnel que je n'entends plus

la mise en oeuvre

Passons maintenant à l'implémentation.

environnement

J'ai installé jupyterlab et fonctionne sur Ubuntu 18.04.

Se préparer à l'apprentissage

Importez le module requis

python


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import random

Créer un jeu de données

Télécharger les données

Téléchargez les données au format numpy depuis le github de KMNIST. Avec jupyterlab, après l'ouverture du Terminal et le déplacement du référentiel wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-imgs.npz wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-labels.npz Ensuite, vous pouvez télécharger l'image et l'étiquette. Au fait, s'il s'agit de KMNIST avec 10 caractères de hiragana, il est inclus par défaut dans torchvision. Si cela ne vous dérange pas, tout comme MNIST normal

python


transform = transforms.Compose(
    [transforms.ToTensor(),
     ])
train_data_10 = torchvision.datasets.KMNIST(root='./data', train=True,download=True,transform=transform)

Vous pouvez l'utiliser si vous le faites.

Prétraitement des données

Si vous souhaitez créer votre propre jeu de données personnalisé avec PyTorch, vous devez définir vous-même le prétraitement. Le prétraitement basé sur l'image est principalement contenu dans torchvision.transforms, donc j'utilise souvent ceci, mais vous pouvez aussi créer le vôtre.

python


class Transform(object):
    def __init__(self):
        pass
    
    def __call__(self, sample):
        sample = np.array(sample, dtype = np.float32)
        sample = torch.tensor(sample)
        return (sample/127.5)-1
    
transform = Transform()

La plupart des fractions gérées par numpy sont np.float64 (nombre à virgule flottante 64 bits), mais PyTorch gère la valeur fractionnaire avec le nombre à virgule flottante 32 bits par défaut, donc une erreur se produira si elles ne sont pas alignées.

De plus, le traitement pour normaliser la valeur de luminosité de l'image dans la plage de [-1,1] est effectué ici. En effet, "Tanh" est utilisé dans la dernière couche de la sortie du générateur qui sortira plus tard, de sorte que la valeur de luminosité de l'image réelle sera ajustée en conséquence.

Classe de jeu de données

Ensuite, nous définirons la classe Dataset. Il s'agit d'un module qui renvoie un ensemble de données et d'étiquettes, et retourne les données prétraitées par la transformation définie précédemment lors de la récupération des données.

python


from tqdm import tqdm

class dataset_full(torch.utils.data.Dataset):
    
    def __init__(self, img, label, transform=None):
        self.transform = transform
        self.data_num = len(img)
        self.data = []
        self.label = []
        for i in tqdm(range(self.data_num)):
            self.data.append([img[i]])
            self.label.append(label[i])
        self.data_num = len(self.data)
            
    def __len__(self):
        return self.data_num
    
    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = np.identity(49)[self.label[idx]]
        out_label = np.array(out_label, dtype = np.float32)
        
        if self.transform:
            out_data = self.transform(out_data)
            
        return out_data, out_label

Si vous mettez le premier tqdm, la progression sera affichée comme un graphique à barres lorsque vous activez l'instruction for, mais cela n'a rien à voir avec le cGAN lui-même.

J'utilise np.identity pour créer un vecteur one-hot d'une longueur de 49.

Former un ensemble de données à partir de données DL

Créez un Dataset à l'aide des classes Transform, Dataset implémentées à partir des données que vous avez téléchargées précédemment.

python


path = %pwd
train_img = np.load('{}/k49-train-imgs.npz'.format(path))
train_img = train_img['arr_0']
train_label = np.load('{}/k49-train-labels.npz'.format(path))
train_label = train_label['arr_0']

train_data = dataset_full(train_img, train_label, transform=transform)

Si vous avez mis dans le tqdm plus tôt, la progression sera affichée lorsque vous exécutez cela. La plupart des données sont de 232 625, mais je ne pense pas que cela prendra longtemps.

Créer DataLoader

Nous avons un ensemble de données, mais nous ne récupérons pas les données directement à partir de cet ensemble de données lors de l'entraînement du modèle. Puisque nous entraînons lot par lot, nous définirons un DataLoader qui renverra des données de taille de lot.

python



batch_size = 256

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers=2)

Si vous définissez shuffle = True, les données extraites de DataLoader seront aléatoires. num_workers est un argument qui spécifie le nombre de cœurs de processeur utilisés par DataLoader, et n'est pas particulièrement pertinent pour le cGAN lui-même.

Le Transform-Dataset-DataLoader jusqu'à présent est résumé dans les articles suivants.
Article de référence
Vérifiez le fonctionnement de base des transformations PyTorch / Dataset / DataLoader

Définir le générateur

Je vais fabriquer le corps du modèle. Generator crée une fausse image (fake_img) à partir du bruit et des étiquettes.

La méthode de mise en œuvre est assez différente selon la personne, mais la structure du Générateur créé cette fois-ci est la suivante. (C'est écrit à la main, mais je suis désolé ...) cGAN_G.png Dans l'entrée, z_dim (dimension du bruit) est 30 et num_class (nombre de classes) est 49 caractères hiragana, il est donc défini sur 49. La fausse image de la sortie a la forme de 1 (canal) x 28 (px) x 28 (px).

python



class Generator(nn.Module):
    def __init__(self, z_dim, num_class):
        super(Generator, self).__init__()
        
        self.fc1 = nn.Linear(z_dim, 300)
        self.bn1 = nn.BatchNorm1d(300)
        self.LReLU1 = nn.LeakyReLU(0.2)
        
        self.fc2 = nn.Linear(num_class, 1500)
        self.bn2 = nn.BatchNorm1d(1500)
        self.LReLU2 = nn.LeakyReLU(0.2)
        
        self.fc3 = nn.Linear(1800, 128 * 7 * 7)
        self.bn3 = nn.BatchNorm1d(128 * 7 * 7)
        self.bo1 = nn.Dropout(p=0.5)
        self.LReLU3 = nn.LeakyReLU(0.2)
        
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), #Changez le nombre de canaux de 128 à 64.
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), #Changement du nombre de canaux de 64 à 1
            nn.Tanh(),
        )
        
        self.init_weights()
        
    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.ConvTranspose2d):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
        
    def forward(self, noise, labels):
        y_1 = self.fc1(noise)
        y_1 = self.bn1(y_1)
        y_1 = self.LReLU1(y_1)
        
        y_2 = self.fc2(labels)
        y_2 = self.bn2(y_2)
        y_2 = self.LReLU2(y_2)
        
        x = torch.cat([y_1, y_2], 1)
        x = self.fc3(x)
        x = self.bo1(x)
        x = self.LReLU3(x)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x

Définition de discriminateur

Vient ensuite Discriminator. Le discriminateur entre l'image authentique / fausse et ses informations d'étiquette et détermine si elle est authentique ou fausse.

La structure du discriminateur créé cette fois est la suivante. cGAN_D.png ʻImg(image d'entrée) vaut 1 (canal) x 28 (px) x 28 (px) pour les authentiques et les faux, etlabels` (étiquette d'entrée) est un vecteur unidimensionnel à 49 dimensions. La sortie détermine si elle est authentique ou non avec une valeur de 0 à 1.

Concattez l'image et étiquetez les informations dans le sens du canal avec cat au milieu. Je pense que l'article de cGAN que j'ai mentionné plus tôt est facile à comprendre sur ce domaine.

python



class Discriminator(nn.Module):
    def __init__(self, num_class):
        super(Discriminator, self).__init__()
        self.num_class = num_class
        
        self.conv = nn.Sequential(
            nn.Conv2d(num_class + 1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(128),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )
        
        self.init_weights()
        
    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
        
    def forward(self, img, labels):
        y_2 = labels.view(-1, self.num_class, 1, 1)
        y_2 = y_2.expand(-1, -1, 28, 28)
        
        x = torch.cat([img, y_2], 1)
        
        x = self.conv(x)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)
        return x

Calcul par époque

1 Créez une fonction pour calculer l'époque.

python



def train_func(D_model, G_model, batch_size, z_dim, num_class, criterion, 
               D_optimizer, G_optimizer, data_loader, device):
    #Mode entraînement
    D_model.train()
    G_model.train()

    #La vraie étiquette est 1
    y_real = torch.ones((batch_size, 1)).to(device)
    D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Etiquette de bruit à mettre en D

    #La fausse étiquette est 0
    y_fake = torch.zeros((batch_size, 1)).to(device)
    D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Etiquette de bruit à mettre en D
    
    #Initialisation de la perte
    D_running_loss = 0
    G_running_loss = 0
    
    #Calcul lot par lot
    for batch_idx, (data, labels) in enumerate(data_loader):
        #Ignorer si inférieur à la taille du lot
        if data.size()[0] != batch_size:
            break
        
        #Création de bruit
        z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Moyenne 0.Générer des nombres aléatoires selon une distribution normale de 5
        
        real_img, label, z = data.to(device), labels.to(device), z.to(device)
        
        #Mise à jour du discriminateur
        D_optimizer.zero_grad()
        
        #Mettre une image réelle dans Discriminator et propager vers l'avant ⇒ Calcul des pertes
        D_real = D_model(real_img, label)
        D_real_loss = criterion(D_real, D_y_real)
        
        #Mettre l'image créée en mettant du bruit dans Generator dans Discriminator et propager vers l'avant ⇒ Calcul de la perte
        fake_img = G_model(z, label)
        D_fake = D_model(fake_img.detach(), label) #fake_Stop Loss calculé dans les images pour qu'il ne se propage pas vers Generator
        D_fake_loss = criterion(D_fake, D_y_fake)
        
        #Minimiser la somme de deux pertes
        D_loss = D_real_loss + D_fake_loss
        
        D_loss.backward()
        D_optimizer.step()
                
        D_running_loss += D_loss.item()
        
        #Mise à jour du générateur
        G_optimizer.zero_grad()
        
        #L'image créée en mettant du bruit dans le Générateur est placée dans le Discriminateur et propagée vers l'avant ⇒ La partie détectée devient Perte
        fake_img_2 = G_model(z, label)
        D_fake_2 = D_model(fake_img_2, label)
        
        #G perte(max(log D)Optimisé avec)
        G_loss = -criterion(D_fake_2, y_fake)
        
        G_loss.backward()
        G_optimizer.step()
        G_running_loss += G_loss.item()
        
    D_running_loss /= len(data_loader)
    G_running_loss /= len(data_loader)
    
    return D_running_loss, G_running_loss

Le «critère» qui apparaît dans l'argument est la classe de perte (dans ce cas, l'entropie croisée binaire). Ce que nous faisons avec cette fonction est dans l'ordre

est.

Ingéniosité de mise en œuvre

C'est un peu vieux, mais cette implémentation intègre l'ingéniosité qui apparaît dans "Comment former un GAN" à NIPS2016 pour réussir l'apprentissage du GAN. Lien GitHub


Article de référence
14 Techniques for Learning GAN (Generative Adversarial Networks)

1. Normaliser l'entrée

Quand j'ai créé la classe Dataset

python


return (sample/127.5)-1

Est-ce. De plus, la dernière couche de Generator est nn.Tanh ().

2. Fonction de perte fixe de G

python


#G perte(max(log D)Optimisé avec)
        G_loss = -criterion(D_fake_2, y_fake)

Est-ce. «D_fake_2» est le jugement de Discriminator, et «y_fake» est un vecteur 128 × 1 0.

3.z est de la distribution gaussienne

Échantillonnez le bruit à mettre dans le générateur à partir d'une distribution normale au lieu d'une distribution uniforme.

python


#Création de bruit
z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Moyenne 0.Générer des nombres aléatoires selon une distribution normale de 5

La moyenne et l'écart type sont appropriés, mais si vous échantillonnez à partir de [0,1] avec une distribution uniforme, vous n'obtiendrez pas une valeur négative, j'ai donc rendu la valeur de bruit échantillonnée presque positive.

4.Batch Norm Toutes les données qui sortent du DataLoader créé ci-dessus sont une image réelle. vice versa

python



fake_img = G_model(z, label)

Ensuite, à partir des informations d'étiquette et du bruit obtenus à partir de DataLoader, nous créons de fausses images de la taille d'un lot.

5. Évitez des choses comme ReLU et Max Pooling où le gradient est clairsemé

LeakyReLU semble être efficace à la fois pour le générateur et le discriminateur, donc toutes les fonctions d'activation sont définies sur LeakyReLU. L'argument 0.2 a été suivi car de nombreuses implémentations ont adopté cette valeur.

6. Utilisez une étiquette bruyante pour l'étiquette correcte de D

L'étiquette du discriminateur est généralement 0 ou 1, mais nous ajoutons du bruit ici. Échantillonnez au hasard de vraies étiquettes de 0,7 à 1,2 et de fausses étiquettes de 0,0 à 0,3.

python



#La vraie étiquette est 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Etiquette de bruit à mettre en D

#La fausse étiquette est 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Etiquette de bruit à mettre en D

C'est la partie. J'utilise habituellement y_real / y_fake, et cette fois j'ai en fait utilisé D_y_real / D_y_fake.

9. Utilisez Adam comme méthode d'optimisation

Il s'agit d'un ancien article, donc un autre optimiseur tel que RAdam pourrait être meilleur maintenant.

14. Mettez le décrochage en G

Cette fois, je n'ai mis Dropout qu'une seule fois dans la couche linéaire de Generator. Cependant, il existe une théorie selon laquelle BatchNorm et Dropout ne sont pas compatibles l'un avec l'autre, donc je ne pense pas qu'il soit vraiment préférable de les mettre tous ensemble.

Afficher l'image créée par Generator

Avant d'entraîner le modèle, définissez une fonction pour afficher l'image créée par le générateur. Faites ceci et vérifiez le degré d'apprentissage du générateur pour chaque époque.

python



import os
from IPython.display import Image
from torchvision.utils import save_image
%matplotlib inline

def Generate_img(epoch, G_model, device, z_dim, noise, var_mode, labels, log_dir = 'logs_cGAN'):
    G_model.eval()
    
    with torch.no_grad():
        if var_mode == True:
            #Nombres aléatoires requis pour la génération
            noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
        else:
            noise = noise

        #Génération d'échantillons avec Generator
        samples = G_model(noise, labels).data.cpu()
        samples = (samples/2)+0.5
        save_image(samples,os.path.join(log_dir, 'epoch_%05d.png' % (epoch)), nrow = 7)
        img = Image('logs_cGAN/epoch_%05d.png' % (epoch))
        display(img)

Tout ce que vous avez à faire est de placer l'image que vous avez créée avec du bruit dans le générateur dans un dossier appelé logs_cGAN et de l'afficher. Il est supposé que le même nombre aléatoire sera utilisé chaque fois que var_mode vaut False.

Formation modèle

Former le modèle.

python



#Valeur de semence fixe pour assurer la reproductibilité
SEED = 1111
random.seed(SEED)
np.random.seed(SEED) 
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

#device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_run(num_epochs, batch_size = batch_size, dataloader = train_loader, device = device):
    
    #Dimension du bruit à mettre dans le générateur
    z_dim = 30
    var_mode = False #Indique s'il faut utiliser un nombre aléatoire différent chaque fois que vous voyez le résultat d'affichage
    #Nombres aléatoires requis pour la génération
    noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
    
    #Nombre de cours
    num_class = 49
    
    #Créez une étiquette à utiliser lors de l'essai de Generator
    labels = []
    for i in range(num_class):
        tmp = np.identity(num_class)[i]
        tmp = np.array(tmp, dtype = np.float32)
        labels.append(tmp)
    label = torch.Tensor(labels).to(device)
    
    #Définition du modèle
    D_model = Discriminator(num_class).to(device)
    G_model = Generator(z_dim, num_class).to(device)
    
    #Définition de la perte(L'argument est le train_Spécifié dans func)
    criterion = nn.BCELoss().to(device)
    
    #Définition de l'optimiseur
    D_optimizer = torch.optim.Adam(D_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
    G_optimizer = torch.optim.Adam(G_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
    
    D_loss_list = []
    G_loss_list = []
    
    all_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        
        D_loss, G_loss = train_func(D_model, G_model, batch_size, z_dim, num_class, criterion, 
                                    D_optimizer, G_optimizer, dataloader, device)

        D_loss_list.append(D_loss)
        G_loss_list.append(G_loss)
        
        secs = int(time.time() - start_time)
        mins = secs / 60
        secs = secs % 60
        
        #Voir les résultats par époque
        print('Epoch: %d' %(epoch + 1), " |Temps requis%d minutes%d secondes" %(mins, secs))
        print(f'\tLoss: {D_loss:.4f}(Discriminator)')
        print(f'\tLoss: {G_loss:.4f}(Generator)')
        
        if (epoch + 1) % 1 == 0:
            Generate_img(epoch, G_model, device, z_dim, noise, var_mode, label)
        
        #Créer un fichier de point de contrôle pour enregistrer le modèle
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch':epoch,
                'model_state_dict':G_model.state_dict(),
                'optimizer_state_dict':G_optimizer.state_dict(),
                'loss':G_loss,
            }, './checkpoint_cGAN/G_model_{}'.format(epoch + 1))
            
    return D_loss_list, G_loss_list

#Tournez le modèle
D_loss_list, G_loss_list = model_run(num_epochs = 100)

C'est un peu long, mais j'affiche le temps requis et la perte pour chaque époque, et j'enregistre le modèle.

résultat

Voyons la transition de la perte du générateur et du discriminateur.

python


import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(10,7))

loss = fig.add_subplot(1,1,1)

loss.plot(range(len(D_loss_list)),D_loss_list,label='Discriminator_loss')
loss.plot(range(len(G_loss_list)),G_loss_list,label='Generator_loss')

loss.set_xlabel('epoch')
loss.set_ylabel('loss')

loss.legend()
loss.grid()

fig.show()

cGAN-result.png

Depuis environ 20 époques, les deux pertes n'ont pas changé. Les pertes de Discriminator et Generator sont loin de 0, donc cela semble fonctionner raisonnablement bien. Au fait, si vous essayez de transformer les personnages générés en gifs dans l'ordre de 1 à 100 époques, cela ressemble à ceci. result_cGAN.gif

Le coin supérieur gauche est «A» et le coin inférieur droit est «ゝ». Il y a pas mal de différences selon le caractère, et il semble que "u", "ku", "sa", "so" et "hi" sont générés de manière stable et correcte, mais "na" et "yu" ont des transitions. C'est féroce.

Voici les résultats de la génération de 5 images pour chaque type. Epoch:5 epoch_5.png

Epoch:50 epoch_50.png

Epoch:100 epoch_100.png

En regardant cela seul, il semble qu'il ne soit pas préférable d'empiler Epoch. "Mu" semble être le meilleur à 5 époques, tandis que "ゑ" semble être le meilleur à 100 époques.

À propos, cela ressemble à ceci lorsque vous récupérez 5 données d'entraînement chacune de la même manière. train_data.png

Il y a des choses que même les gens modernes ne peuvent pas lire. «Su» et «mi» sont assez différents de leurs formes actuelles. En regardant cela, je pense que les performances du modèle sont assez bonnes.

Résumé

J'ai essayé de générer des caractères indésirables avec cGAN. Je pense qu'il reste encore beaucoup à faire dans la mise en œuvre, mais je pense que le résultat en lui-même est raisonnable. Cela fait longtemps, mais j'espère que cela aide même en partie.

De plus, certaines personnes ont implémenté des MNIST (nombres manuscrits) en utilisant PyTorch avec cGAN. Il y a de nombreuses parties différentes telles que la structure du modèle, donc je pense que cela est également utile.


Article de référence
J'ai essayé de générer des caractères manuscrits par apprentissage en profondeur [Pytorch x MNIST x CGAN]

finalement

A l'origine, j'étais légèrement motivé à penser "Est-il possible de générer des phrases manuscrites?", Donc je vais l'essayer à la fin.

Chargez le poids du modèle à partir du fichier de point de contrôle enregistré et essayez pkl une fois.

python



import cloudpickle
%matplotlib inline
#Spécifiez l'époque à récupérer
point = 50

#Définir la structure du modèle
z_dim = 30
num_class = 49
G = Generator(z_dim = z_dim, num_class = num_class)

#Extraire le point de contrôle
checkpoint = torch.load('./checkpoint_cGAN/G_model_{}'.format(point))

#Mettre les paramètres dans le générateur
G.load_state_dict(checkpoint['model_state_dict'])

#Restez en mode vérification
G.eval()

#Économisez avec cornichon
with open ('KMNIST_cGAN.pkl','wb')as f:
    cloudpickle.dump(G,f)

Il semble que vous puissiez le rendre pkl en utilisant un module appelé cloudpickle au lieu du pickle habituel.

Ouvrons ce fichier pkl et générons une phrase.

python



letter = 'Aiue Okakikuke Kosashi Suseso Tachi Nune no Hahifuhe Homami Mumemoya Yuyorari Rurerowa'

strs = input()
with open('KMNIST_cGAN.pkl','rb')as f:
    Generator = cloudpickle.load(f)
    
for i in range(len(str(strs))):
    noise = torch.normal(mean = 0.5, std = 0.2, size = (1, 30))
    str_index = letter.index(strs[i])
    tmp = np.identity(49)[str_index]
    tmp = np.array(tmp, dtype = np.float32)
    label = [tmp]
    
    img = Generator(noise, torch.Tensor(label))
    img = img.reshape((28,28))
    img = img.detach().numpy().tolist()
    
    if i == 0:
        comp_img = img
    else:
        comp_img.extend(img)
        
save_image(torch.tensor(comp_img), './sentence.png', nrow=len(str(strs)))
img = Image('./sentence.png')
display(img)

Le résultat ressemble à ceci. sentence.png

"Je ne sais plus rien" ...

Recommended Posts

Génération du caractère indésirable MNIST (KMNIST) avec cGAN (GAN conditionnel)
Enregistrez la sortie du GAN conditionnel pour chaque classe ~ Avec l'implémentation cGAN par PyTorch ~
Implémentation du GAN conditionnel avec chainer