Pour ma propre pratique, j'ai implémenté et formé CVAE, qui est un type d'apprentissage profond. Cet article est une description de niveau mémo et est rédigé en partant du principe que vous connaissez la VAE. Notez s'il vous plaît.
<détails> Il est également implémenté à l'aide de Jupyter Notebook. Voici quelques pages auxquelles j'ai fait référence lors de la mise en œuvre. En outre, je me réfère également à l'exemple de mise en œuvre de Pytorch. ** CVAE (Conditional Variational Auto Encoder) ** est une méthode avancée de VAE.
Dans VAE normal, les données sont entrées dans l'encodeur et les variables latentes sont entrées dans le décodeur, mais dans CVAE, l'état des données est ajouté à celles-ci. Cela vous donne les avantages suivants: Cette fois, nous allons implémenter CVAE avec Pytorch et former MNIST (ensemble de données de caractères manuscrits). résultat Voici une liste de points de mise en œuvre et d'apprentissage. --Utilisez 6000 données d'apprentissage de VAE a deux applications, la suppression de dimension et la génération de données, mais cette fois nous nous concentrerons sur la génération de données.
Pensez à créer une nouvelle image manuscrite à l'aide du décodeur CVAE que vous avez appris précédemment. Les informations d'étiquette fournies au décodeur sont fixées à "5", 100 nombres aléatoires qui suivent la distribution normale standard sont générés, et les images correspondant à chacun sont générées. résultat Certains d'entre eux sont déformés, mais nous sommes capables de générer diverses images "5". J'ai recherché les nombres en gras dans l'image de test de Il est écrit très épais comme "4".
Utilisez Encoder pour trouver la variable latente correspondant à ces données. résultat Ce vecteur à 16 dimensions contient les informations de l'image de ** autre que l'étiquette ** donnée au moment de la formation. En d'autres termes, vous devriez avoir l'information "très épaisse", pas l'information "c'est sous la forme de 4". Par conséquent, en utilisant cette variable latente, essayez de générer une image tout en modifiant les informations d'étiquette fournies au décodeur. résultat "2" est un peu suspect, mais je suis capable de générer une image avec des nombres épais. Je connaissais CVAE depuis longtemps, mais c'était la première fois que je l'implémentais. Je suis content que cela ait fonctionné. Il est important non seulement de le connaître mais aussi de le mettre en œuvre.
Certaines des images générées n'étaient pas jolies, mais elles peuvent être résolues en utilisant la convolution ou la convolution de translocation dans le réseau VAE.
Bien que cette fois omis, le système VAE reconnaît qu'il est important d'analyser quelles entités sont mappées et où dans l'espace de faible dimension. J'aimerais faire cette analyse cette fois. [^ 1]: C'est pour que les données avec toutes les étiquettes existent dans le mini-lot afin que l'image du mini-lot par Encoder suive la distribution normale standard dans l'espace variable latent.
Recommended Posts
Article de référence
Qu'est-ce que CVAE
Mise en œuvre et apprentissage
python
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline
DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50
# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
class CVAE(nn.Module):
def __init__(self, zdim):
super().__init__()
self._zdim = zdim
self._in_units = 28 * 28
hidden_units = 512
self._encoder = nn.Sequential(
nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, hidden_units),
nn.ReLU(inplace=True),
)
self._to_mean = nn.Linear(hidden_units, zdim)
self._to_lnvar = nn.Linear(hidden_units, zdim)
self._decoder = nn.Sequential(
nn.Linear(zdim + CLASS_SIZE, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, self._in_units),
nn.Sigmoid()
)
def encode(self, x, labels):
in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
in_[:, :self._in_units] = x
in_[:, self._in_units:] = labels
h = self._encoder(in_)
mean = self._to_mean(h)
lnvar = self._to_lnvar(h)
return mean, lnvar
def decode(self, z, labels):
in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
in_[:, :self._zdim] = z
in_[:, self._zdim:] = labels
return self._decoder(in_)
def to_onehot(label):
return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]
# Train
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transforms.ToTensor(),
download=True,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0
)
model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for e in range(NUM_EPOCHS):
train_loss = 0
for i, (images, labels) in enumerate(train_loader):
labels = to_onehot(labels)
# Reconstruction images
# Encode images
x = images.view(-1, 28*28*1).to(DEVICE)
mean, lnvar = model.encode(x, labels)
std = lnvar.exp().sqrt()
epsilon = torch.randn(ZDIM, device=DEVICE)
# Decode latent variables
z = mean + std * epsilon
y = model.decode(z, labels)
# Compute loss
kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
loss = (-1 * kld + bce).mean()
# Update model
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item() * x.shape[0]
print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')
epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292
#Omission
epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735
torchvision.datasets.MNIST
pour l'apprentissage et définissez le nombre d'époques sur 50.
et
decodesans implémenter
forward`
--Convertissez l'étiquette de l'ensemble de données (numéro écrit) en un vecteur unique et ajoutez-la aux entrées de l'encodeur et du décodeur
--La taille du mini-lot au moment de l'apprentissage est de 256 [^ 1]
--Comprend un MLP simple pour l'encodeur et le décodeur
--Réglez la dimension de la variable latente à 16.Génération d'images par CVAE
Génération d'images "5"
python
# Generation data with label '5'
NUM_GENERATION = 100
os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
label = torch.tensor([5], device=DEVICE)
with torch.no_grad():
y = model.decode(z, to_onehot(label))
y = y.reshape(28, 28).cpu().detach().numpy()
# Save image
fig, ax = plt.subplots()
ax.imshow(y)
ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
ax.tick_params(
labelbottom=False,
labelleft=False,
bottom=False,
left=False,
)
plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
plt.close(fig)
Génération d'une image numérique épaisse
torchvision.datasets.MNIST
.
L'image suivante est la 49e image de l'ensemble de données.python
test_dataset = torchvision.datasets.MNIST(
root='./data',
train=False,
transform=transforms.ToTensor(),
download=True,
)
target_image, label = list(test_dataset)[48]
x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
mean, _ = model.encode(x, to_onehot(label))
z = mean
print(f'z = {z.cpu().detach().numpy().squeeze()}')
z = [ 0.7933388 2.4768877 0.49229255 -0.09540698 -1.7999544 0.03376897
0.01600834 1.3863252 0.14656337 -0.14543885 0.04157912 0.13938689
-0.2016176 0.5204378 -0.08096244 1.0930295 ]
python
os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
with torch.no_grad():
y = model.decode(z, to_onehot(label))
y = y.reshape(28, 28).cpu().detach().numpy()
fig, ax = plt.subplots()
ax.imshow(y)
ax.set_title(f'Generation(label={label})')
ax.tick_params(
labelbottom=False,
labelleft=False,
bottom=False,
left=False,
)
plt.savefig(f'img/cvae/generation/fat/img{label}')
plt.close(fig)
en conclusion