J'ai essayé de générer un caractère indésirable MNIST en utilisant une sorte de GAN, cGAN (GAN conditionnel). Pour des aspects théoriques détaillés, veuillez vous référer aux liens qui seront utiles le cas échéant.
--Je veux implémenter GAN avec PyTorch ――Je souhaite créer un modèle capable de générer l'image souhaitée
J'espère que cela sera utile pour ceux qui aiment.
À propos de l'ensemble de données utilisé cette fois. KMNIST est un ensemble de données créé pour l'apprentissage automatique en tant que dérivé du "Japanese Classics Kuzuji Data Set" créé par le Humanities Open Data Sharing Center. Vous pouvez le télécharger depuis GitHub Link.
"KMNIST Dataset" (créé par CODH) "Japanese Classics Kuzuji Dataset" (Kokubunken et al.) Adapté doi: 10.20676 / 00000341
Comme ce MNIST (numéro manuscrit) que connaît tous ceux qui ont fait du machine learning, une image mesure 1 x 28px x 28px.
Les trois types d'ensembles de données suivants peuvent être téléchargés à partir du référentiel au format compressé numpy.array.
--kuzushiji-MNIST (10 caractères de hiragana) --kuzushiji-49 (49 caractères de hiragana) --kuzushiji-kanji (3832 kanji)
Parmi ceux-ci, "kuzushiji-49" sera utilisé cette fois. Il n'y a pas de raison profonde particulière, mais si 49 caractères hiragana peuvent être ciblés et générés, est-il possible de générer des phrases manuscrites? Je pensais que c'était une motivation légère.
Abordons brièvement GAN avant cGAN. GAN est une abréviation de "Generative Adversarial Network" (= réseau de génération hostile) et est une sorte de modèle de génération d'apprentissage en profondeur. C'est particulièrement efficace dans le domaine de la génération d'images, et je pense que le résultat de la génération d'images faciales de personnes qui n'existent pas dans le monde est célèbre.
Ce qui suit est un schéma de modèle approximatif du GAN. "G" signifie Generator et "D" signifie Discriminator.
Generator générera une fausse image aussi proche de la réalité que possible à partir du bruit. Le discriminateur fait la distinction entre l'image réelle (real_img) extraite de l'ensemble de données et la fausse image (fake_img) créée par le générateur (vrai ou faux).
En répétant cet apprentissage, le générateur tente de créer une image aussi proche que possible de la chose réelle que le discriminateur ne peut pas détecter, et le discriminateur essaie de détecter le faux créé par le générateur et la chose réelle dérivée de l'ensemble de données, de sorte que le générateur est généré. La précision augmentera.
Article de référence
GAN (1) Comprendre la structure de base que je n'entends plus
Les articles liés au GAN sont organisés dans This GitHub Repository.
Ensuite, je voudrais parler du GAN conditionnel utilisé cette fois. En termes simples, c'est ** "GAN qui peut générer l'image souhaitée" **. L'idée est simple, c'est comme décider de l'image à générer en ajoutant des informations d'étiquette à l'entrée du discriminateur et du générateur.
Le document original est ici
C'est la même chose qu'un GAN normal sauf que ** "Entrez le libellé pendant l'entraînement" **. De plus, bien que les informations d'étiquette soient utilisées, Discriminator détermine uniquement "si l'image est authentique".
Article de référence
Implémentation de GAN (6) GAN conditionnel que je n'entends plus
Passons maintenant à l'implémentation.
J'ai installé jupyterlab et fonctionne sur Ubuntu 18.04.
Importez le module requis
python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import random
Téléchargez les données au format numpy depuis le github de KMNIST.
Avec jupyterlab, après l'ouverture du Terminal et le déplacement du référentiel
wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-imgs.npz
wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-labels.npz
Ensuite, vous pouvez télécharger l'image et l'étiquette.
Au fait, s'il s'agit de KMNIST avec 10 caractères de hiragana, il est inclus par défaut dans torchvision. Si cela ne vous dérange pas, tout comme MNIST normal
python
transform = transforms.Compose(
[transforms.ToTensor(),
])
train_data_10 = torchvision.datasets.KMNIST(root='./data', train=True,download=True,transform=transform)
Vous pouvez l'utiliser si vous le faites.
Si vous souhaitez créer votre propre jeu de données personnalisé avec PyTorch, vous devez définir vous-même le prétraitement. Le prétraitement basé sur l'image est principalement contenu dans torchvision.transforms
, donc j'utilise souvent ceci, mais vous pouvez aussi créer le vôtre.
python
class Transform(object):
def __init__(self):
pass
def __call__(self, sample):
sample = np.array(sample, dtype = np.float32)
sample = torch.tensor(sample)
return (sample/127.5)-1
transform = Transform()
La plupart des fractions gérées par numpy sont np.float64
(nombre à virgule flottante 64 bits), mais PyTorch gère la valeur fractionnaire avec le nombre à virgule flottante 32 bits par défaut, donc une erreur se produira si elles ne sont pas alignées.
De plus, le traitement pour normaliser la valeur de luminosité de l'image dans la plage de [-1,1] est effectué ici. En effet, "Tanh" est utilisé dans la dernière couche de la sortie du générateur qui sortira plus tard, de sorte que la valeur de luminosité de l'image réelle sera ajustée en conséquence.
Ensuite, nous définirons la classe Dataset.
Il s'agit d'un module qui renvoie un ensemble de données et d'étiquettes, et retourne les données prétraitées par la transformation
définie précédemment lors de la récupération des données.
python
from tqdm import tqdm
class dataset_full(torch.utils.data.Dataset):
def __init__(self, img, label, transform=None):
self.transform = transform
self.data_num = len(img)
self.data = []
self.label = []
for i in tqdm(range(self.data_num)):
self.data.append([img[i]])
self.label.append(label[i])
self.data_num = len(self.data)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx]
out_label = np.identity(49)[self.label[idx]]
out_label = np.array(out_label, dtype = np.float32)
if self.transform:
out_data = self.transform(out_data)
return out_data, out_label
Si vous mettez le premier tqdm
, la progression sera affichée comme un graphique à barres lorsque vous activez l'instruction for, mais cela n'a rien à voir avec le cGAN lui-même.
J'utilise np.identity
pour créer un vecteur one-hot d'une longueur de 49.
Créez un Dataset à l'aide des classes Transform
, Dataset
implémentées à partir des données que vous avez téléchargées précédemment.
python
path = %pwd
train_img = np.load('{}/k49-train-imgs.npz'.format(path))
train_img = train_img['arr_0']
train_label = np.load('{}/k49-train-labels.npz'.format(path))
train_label = train_label['arr_0']
train_data = dataset_full(train_img, train_label, transform=transform)
Si vous avez mis dans le tqdm plus tôt, la progression sera affichée lorsque vous exécutez cela. La plupart des données sont de 232 625, mais je ne pense pas que cela prendra longtemps.
Nous avons un ensemble de données, mais nous ne récupérons pas les données directement à partir de cet ensemble de données lors de l'entraînement du modèle. Puisque nous entraînons lot par lot, nous définirons un DataLoader qui renverra des données de taille de lot.
python
batch_size = 256
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers=2)
Si vous définissez shuffle = True
, les données extraites de DataLoader seront aléatoires. num_workers
est un argument qui spécifie le nombre de cœurs de processeur utilisés par DataLoader, et n'est pas particulièrement pertinent pour le cGAN lui-même.
Le Transform-Dataset-DataLoader jusqu'à présent est résumé dans les articles suivants.
Article de référence
Vérifiez le fonctionnement de base des transformations PyTorch / Dataset / DataLoader
Je vais fabriquer le corps du modèle. Generator crée une fausse image (fake_img) à partir du bruit et des étiquettes.
La méthode de mise en œuvre est assez différente selon la personne, mais la structure du Générateur créé cette fois-ci est la suivante.
(C'est écrit à la main, mais je suis désolé ...)
Dans l'entrée, z_dim
(dimension du bruit) est 30 et num_class
(nombre de classes) est 49 caractères hiragana, il est donc défini sur 49. La fausse image de la sortie a la forme de 1 (canal) x 28 (px) x 28 (px).
python
class Generator(nn.Module):
def __init__(self, z_dim, num_class):
super(Generator, self).__init__()
self.fc1 = nn.Linear(z_dim, 300)
self.bn1 = nn.BatchNorm1d(300)
self.LReLU1 = nn.LeakyReLU(0.2)
self.fc2 = nn.Linear(num_class, 1500)
self.bn2 = nn.BatchNorm1d(1500)
self.LReLU2 = nn.LeakyReLU(0.2)
self.fc3 = nn.Linear(1800, 128 * 7 * 7)
self.bn3 = nn.BatchNorm1d(128 * 7 * 7)
self.bo1 = nn.Dropout(p=0.5)
self.LReLU3 = nn.LeakyReLU(0.2)
self.deconv = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), #Changez le nombre de canaux de 128 à 64.
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), #Changement du nombre de canaux de 64 à 1
nn.Tanh(),
)
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.ConvTranspose2d):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm1d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
def forward(self, noise, labels):
y_1 = self.fc1(noise)
y_1 = self.bn1(y_1)
y_1 = self.LReLU1(y_1)
y_2 = self.fc2(labels)
y_2 = self.bn2(y_2)
y_2 = self.LReLU2(y_2)
x = torch.cat([y_1, y_2], 1)
x = self.fc3(x)
x = self.bo1(x)
x = self.LReLU3(x)
x = x.view(-1, 128, 7, 7)
x = self.deconv(x)
return x
Vient ensuite Discriminator. Le discriminateur entre l'image authentique / fausse et ses informations d'étiquette et détermine si elle est authentique ou fausse.
La structure du discriminateur créé cette fois est la suivante.
ʻImg(image d'entrée) vaut 1 (canal) x 28 (px) x 28 (px) pour les authentiques et les faux, et
labels` (étiquette d'entrée) est un vecteur unidimensionnel à 49 dimensions. La sortie détermine si elle est authentique ou non avec une valeur de 0 à 1.
Concattez l'image et étiquetez les informations dans le sens du canal avec cat
au milieu. Je pense que l'article de cGAN que j'ai mentionné plus tôt est facile à comprendre sur ce domaine.
python
class Discriminator(nn.Module):
def __init__(self, num_class):
super(Discriminator, self).__init__()
self.num_class = num_class
self.conv = nn.Sequential(
nn.Conv2d(num_class + 1, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(128),
)
self.fc = nn.Sequential(
nn.Linear(128 * 7 * 7, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 1),
nn.Sigmoid(),
)
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm1d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
def forward(self, img, labels):
y_2 = labels.view(-1, self.num_class, 1, 1)
y_2 = y_2.expand(-1, -1, 28, 28)
x = torch.cat([img, y_2], 1)
x = self.conv(x)
x = x.view(-1, 128 * 7 * 7)
x = self.fc(x)
return x
1 Créez une fonction pour calculer l'époque.
python
def train_func(D_model, G_model, batch_size, z_dim, num_class, criterion,
D_optimizer, G_optimizer, data_loader, device):
#Mode entraînement
D_model.train()
G_model.train()
#La vraie étiquette est 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Etiquette de bruit à mettre en D
#La fausse étiquette est 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Etiquette de bruit à mettre en D
#Initialisation de la perte
D_running_loss = 0
G_running_loss = 0
#Calcul lot par lot
for batch_idx, (data, labels) in enumerate(data_loader):
#Ignorer si inférieur à la taille du lot
if data.size()[0] != batch_size:
break
#Création de bruit
z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Moyenne 0.Générer des nombres aléatoires selon une distribution normale de 5
real_img, label, z = data.to(device), labels.to(device), z.to(device)
#Mise à jour du discriminateur
D_optimizer.zero_grad()
#Mettre une image réelle dans Discriminator et propager vers l'avant ⇒ Calcul des pertes
D_real = D_model(real_img, label)
D_real_loss = criterion(D_real, D_y_real)
#Mettre l'image créée en mettant du bruit dans Generator dans Discriminator et propager vers l'avant ⇒ Calcul de la perte
fake_img = G_model(z, label)
D_fake = D_model(fake_img.detach(), label) #fake_Stop Loss calculé dans les images pour qu'il ne se propage pas vers Generator
D_fake_loss = criterion(D_fake, D_y_fake)
#Minimiser la somme de deux pertes
D_loss = D_real_loss + D_fake_loss
D_loss.backward()
D_optimizer.step()
D_running_loss += D_loss.item()
#Mise à jour du générateur
G_optimizer.zero_grad()
#L'image créée en mettant du bruit dans le Générateur est placée dans le Discriminateur et propagée vers l'avant ⇒ La partie détectée devient Perte
fake_img_2 = G_model(z, label)
D_fake_2 = D_model(fake_img_2, label)
#G perte(max(log D)Optimisé avec)
G_loss = -criterion(D_fake_2, y_fake)
G_loss.backward()
G_optimizer.step()
G_running_loss += G_loss.item()
D_running_loss /= len(data_loader)
G_running_loss /= len(data_loader)
return D_running_loss, G_running_loss
Le «critère» qui apparaît dans l'argument est la classe de perte (dans ce cas, l'entropie croisée binaire). Ce que nous faisons avec cette fonction est dans l'ordre
est.
C'est un peu vieux, mais cette implémentation intègre l'ingéniosité qui apparaît dans "Comment former un GAN" à NIPS2016 pour réussir l'apprentissage du GAN. Lien GitHub
Article de référence
14 Techniques for Learning GAN (Generative Adversarial Networks)
Quand j'ai créé la classe Dataset
python
return (sample/127.5)-1
Est-ce. De plus, la dernière couche de Generator est nn.Tanh ()
.
python
#G perte(max(log D)Optimisé avec)
G_loss = -criterion(D_fake_2, y_fake)
Est-ce. «D_fake_2» est le jugement de Discriminator, et «y_fake» est un vecteur 128 × 1 0.
Échantillonnez le bruit à mettre dans le générateur à partir d'une distribution normale au lieu d'une distribution uniforme.
python
#Création de bruit
z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Moyenne 0.Générer des nombres aléatoires selon une distribution normale de 5
La moyenne et l'écart type sont appropriés, mais si vous échantillonnez à partir de [0,1] avec une distribution uniforme, vous n'obtiendrez pas une valeur négative, j'ai donc rendu la valeur de bruit échantillonnée presque positive.
4.Batch Norm Toutes les données qui sortent du DataLoader créé ci-dessus sont une image réelle. vice versa
python
fake_img = G_model(z, label)
Ensuite, à partir des informations d'étiquette et du bruit obtenus à partir de DataLoader, nous créons de fausses images de la taille d'un lot.
LeakyReLU semble être efficace à la fois pour le générateur et le discriminateur, donc toutes les fonctions d'activation sont définies sur LeakyReLU. L'argument 0.2 a été suivi car de nombreuses implémentations ont adopté cette valeur.
L'étiquette du discriminateur est généralement 0 ou 1, mais nous ajoutons du bruit ici. Échantillonnez au hasard de vraies étiquettes de 0,7 à 1,2 et de fausses étiquettes de 0,0 à 0,3.
python
#La vraie étiquette est 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Etiquette de bruit à mettre en D
#La fausse étiquette est 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Etiquette de bruit à mettre en D
C'est la partie. J'utilise habituellement y_real
/ y_fake
, et cette fois j'ai en fait utilisé D_y_real
/ D_y_fake
.
Il s'agit d'un ancien article, donc un autre optimiseur tel que RAdam pourrait être meilleur maintenant.
Cette fois, je n'ai mis Dropout qu'une seule fois dans la couche linéaire de Generator. Cependant, il existe une théorie selon laquelle BatchNorm et Dropout ne sont pas compatibles l'un avec l'autre, donc je ne pense pas qu'il soit vraiment préférable de les mettre tous ensemble.
Avant d'entraîner le modèle, définissez une fonction pour afficher l'image créée par le générateur. Faites ceci et vérifiez le degré d'apprentissage du générateur pour chaque époque.
python
import os
from IPython.display import Image
from torchvision.utils import save_image
%matplotlib inline
def Generate_img(epoch, G_model, device, z_dim, noise, var_mode, labels, log_dir = 'logs_cGAN'):
G_model.eval()
with torch.no_grad():
if var_mode == True:
#Nombres aléatoires requis pour la génération
noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
else:
noise = noise
#Génération d'échantillons avec Generator
samples = G_model(noise, labels).data.cpu()
samples = (samples/2)+0.5
save_image(samples,os.path.join(log_dir, 'epoch_%05d.png' % (epoch)), nrow = 7)
img = Image('logs_cGAN/epoch_%05d.png' % (epoch))
display(img)
Tout ce que vous avez à faire est de placer l'image que vous avez créée avec du bruit dans le générateur dans un dossier appelé logs_cGAN
et de l'afficher.
Il est supposé que le même nombre aléatoire sera utilisé chaque fois que var_mode vaut False.
Former le modèle.
python
#Valeur de semence fixe pour assurer la reproductibilité
SEED = 1111
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
#device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def model_run(num_epochs, batch_size = batch_size, dataloader = train_loader, device = device):
#Dimension du bruit à mettre dans le générateur
z_dim = 30
var_mode = False #Indique s'il faut utiliser un nombre aléatoire différent chaque fois que vous voyez le résultat d'affichage
#Nombres aléatoires requis pour la génération
noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
#Nombre de cours
num_class = 49
#Créez une étiquette à utiliser lors de l'essai de Generator
labels = []
for i in range(num_class):
tmp = np.identity(num_class)[i]
tmp = np.array(tmp, dtype = np.float32)
labels.append(tmp)
label = torch.Tensor(labels).to(device)
#Définition du modèle
D_model = Discriminator(num_class).to(device)
G_model = Generator(z_dim, num_class).to(device)
#Définition de la perte(L'argument est le train_Spécifié dans func)
criterion = nn.BCELoss().to(device)
#Définition de l'optimiseur
D_optimizer = torch.optim.Adam(D_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
G_optimizer = torch.optim.Adam(G_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
D_loss_list = []
G_loss_list = []
all_time = time.time()
for epoch in range(num_epochs):
start_time = time.time()
D_loss, G_loss = train_func(D_model, G_model, batch_size, z_dim, num_class, criterion,
D_optimizer, G_optimizer, dataloader, device)
D_loss_list.append(D_loss)
G_loss_list.append(G_loss)
secs = int(time.time() - start_time)
mins = secs / 60
secs = secs % 60
#Voir les résultats par époque
print('Epoch: %d' %(epoch + 1), " |Temps requis%d minutes%d secondes" %(mins, secs))
print(f'\tLoss: {D_loss:.4f}(Discriminator)')
print(f'\tLoss: {G_loss:.4f}(Generator)')
if (epoch + 1) % 1 == 0:
Generate_img(epoch, G_model, device, z_dim, noise, var_mode, label)
#Créer un fichier de point de contrôle pour enregistrer le modèle
if (epoch + 1) % 5 == 0:
torch.save({
'epoch':epoch,
'model_state_dict':G_model.state_dict(),
'optimizer_state_dict':G_optimizer.state_dict(),
'loss':G_loss,
}, './checkpoint_cGAN/G_model_{}'.format(epoch + 1))
return D_loss_list, G_loss_list
#Tournez le modèle
D_loss_list, G_loss_list = model_run(num_epochs = 100)
C'est un peu long, mais j'affiche le temps requis et la perte pour chaque époque, et j'enregistre le modèle.
Voyons la transition de la perte du générateur et du discriminateur.
python
import matplotlib.pyplot as plt
%matplotlib inline
fig = plt.figure(figsize=(10,7))
loss = fig.add_subplot(1,1,1)
loss.plot(range(len(D_loss_list)),D_loss_list,label='Discriminator_loss')
loss.plot(range(len(G_loss_list)),G_loss_list,label='Generator_loss')
loss.set_xlabel('epoch')
loss.set_ylabel('loss')
loss.legend()
loss.grid()
fig.show()
Depuis environ 20 époques, les deux pertes n'ont pas changé. Les pertes de Discriminator et Generator sont loin de 0, donc cela semble fonctionner raisonnablement bien. Au fait, si vous essayez de transformer les personnages générés en gifs dans l'ordre de 1 à 100 époques, cela ressemble à ceci.
Le coin supérieur gauche est «A» et le coin inférieur droit est «ゝ». Il y a pas mal de différences selon le caractère, et il semble que "u", "ku", "sa", "so" et "hi" sont générés de manière stable et correcte, mais "na" et "yu" ont des transitions. C'est féroce.
Voici les résultats de la génération de 5 images pour chaque type. Epoch:5
Epoch:50
Epoch:100
En regardant cela seul, il semble qu'il ne soit pas préférable d'empiler Epoch. "Mu" semble être le meilleur à 5 époques, tandis que "ゑ" semble être le meilleur à 100 époques.
À propos, cela ressemble à ceci lorsque vous récupérez 5 données d'entraînement chacune de la même manière.
Il y a des choses que même les gens modernes ne peuvent pas lire. «Su» et «mi» sont assez différents de leurs formes actuelles. En regardant cela, je pense que les performances du modèle sont assez bonnes.
J'ai essayé de générer des caractères indésirables avec cGAN. Je pense qu'il reste encore beaucoup à faire dans la mise en œuvre, mais je pense que le résultat en lui-même est raisonnable. Cela fait longtemps, mais j'espère que cela aide même en partie.
De plus, certaines personnes ont implémenté des MNIST (nombres manuscrits) en utilisant PyTorch avec cGAN. Il y a de nombreuses parties différentes telles que la structure du modèle, donc je pense que cela est également utile.
Article de référence
J'ai essayé de générer des caractères manuscrits par apprentissage en profondeur [Pytorch x MNIST x CGAN]
A l'origine, j'étais légèrement motivé à penser "Est-il possible de générer des phrases manuscrites?", Donc je vais l'essayer à la fin.
Chargez le poids du modèle à partir du fichier de point de contrôle enregistré et essayez pkl une fois.
python
import cloudpickle
%matplotlib inline
#Spécifiez l'époque à récupérer
point = 50
#Définir la structure du modèle
z_dim = 30
num_class = 49
G = Generator(z_dim = z_dim, num_class = num_class)
#Extraire le point de contrôle
checkpoint = torch.load('./checkpoint_cGAN/G_model_{}'.format(point))
#Mettre les paramètres dans le générateur
G.load_state_dict(checkpoint['model_state_dict'])
#Restez en mode vérification
G.eval()
#Économisez avec cornichon
with open ('KMNIST_cGAN.pkl','wb')as f:
cloudpickle.dump(G,f)
Il semble que vous puissiez le rendre pkl en utilisant un module appelé cloudpickle
au lieu du pickle
habituel.
Ouvrons ce fichier pkl et générons une phrase.
python
letter = 'Aiue Okakikuke Kosashi Suseso Tachi Nune no Hahifuhe Homami Mumemoya Yuyorari Rurerowa'
strs = input()
with open('KMNIST_cGAN.pkl','rb')as f:
Generator = cloudpickle.load(f)
for i in range(len(str(strs))):
noise = torch.normal(mean = 0.5, std = 0.2, size = (1, 30))
str_index = letter.index(strs[i])
tmp = np.identity(49)[str_index]
tmp = np.array(tmp, dtype = np.float32)
label = [tmp]
img = Generator(noise, torch.Tensor(label))
img = img.reshape((28,28))
img = img.detach().numpy().tolist()
if i == 0:
comp_img = img
else:
comp_img.extend(img)
save_image(torch.tensor(comp_img), './sentence.png', nrow=len(str(strs)))
img = Image('./sentence.png')
display(img)
Le résultat ressemble à ceci.
"Je ne sais plus rien" ...