As a personal exercise, I implemented and studied CVAE, a kind of deep learning model. Please note that this article is a memo-level write-up and assumes you already know about VAE.
The implementation was done in a Jupyter Notebook.
These are the pages I referred to during the implementation:

- [Qiita] Variational Autoencoder Thorough Explanation
- [Qiita] Journey around the deep generative model (2): VAE

I also referred to the PyTorch example implementation.
**CVAE (Conditional Variational Auto Encoder)** is an extension of VAE. In a plain VAE, the data is fed to the Encoder and the latent variable to the Decoder; in CVAE, condition (label) information about the data is additionally fed to both. This gives you the following benefits:

- When reducing dimensionality with the Encoder, features other than the data's label can be captured.
- When generating data with the Decoder, you can specify the condition of the data you want.

This time, we will implement CVAE with PyTorch and train it on MNIST (a handwritten digit dataset).
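For reference (my own summary of the standard formulation, not something quoted from the articles above), the quantity minimized in the training code below is the negative conditional ELBO: a reconstruction term plus a KL term that pulls the approximate posterior toward the standard normal prior.

```math
\mathcal{L}(x, c) = -\mathbb{E}_{q_\phi(z \mid x, c)}\left[\log p_\theta(x \mid z, c)\right] + D_{\mathrm{KL}}\!\left(q_\phi(z \mid x, c) \,\|\, \mathcal{N}(0, I)\right)
```

In the code, the expectation is approximated with a single reparameterized sample, the reconstruction term becomes a pixel-wise binary cross entropy, and the KL term uses the usual closed form for a diagonal Gaussian.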
```python
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline
DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50
# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
class CVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._zdim = zdim
        self._in_units = 28 * 28
        hidden_units = 512

        self._encoder = nn.Sequential(
            nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
        )
        self._to_mean = nn.Linear(hidden_units, zdim)
        self._to_lnvar = nn.Linear(hidden_units, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, self._in_units),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        # Concatenate the flattened image with the one-hot label
        in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
        in_[:, :self._in_units] = x
        in_[:, self._in_units:] = labels
        h = self._encoder(in_)
        mean = self._to_mean(h)
        lnvar = self._to_lnvar(h)
        return mean, lnvar

    def decode(self, z, labels):
        # Concatenate the latent variable with the one-hot label
        in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
        in_[:, :self._zdim] = z
        in_[:, self._zdim:] = labels
        return self._decoder(in_)


def to_onehot(label):
    return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]
# Train
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for e in range(NUM_EPOCHS):
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        labels = to_onehot(labels)

        # Reconstruct the images
        # Encode images
        x = images.view(-1, 28*28*1).to(DEVICE)
        mean, lnvar = model.encode(x, labels)
        std = lnvar.exp().sqrt()
        # Noise for the reparameterization trick (shared across the mini-batch here)
        epsilon = torch.randn(ZDIM, device=DEVICE)

        # Decode latent variables
        z = mean + std * epsilon
        y = model.decode(z, labels)

        # Compute loss (kld here is the negative KL term, so loss = -kld + bce)
        kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
        bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
        loss = (-1 * kld + bce).mean()

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x.shape[0]

    print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')
```
Result:

```
epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292
# ...(omitted)...
epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735
```
The main implementation and training points are:

- Use the 60,000 training images of `torchvision.datasets.MNIST` and train for 50 epochs.
- Design a `CVAE` class holding the Encoder and Decoder, and implement `encode` and `decode` methods instead of `forward`.
- Convert the dataset label (the written digit) to a one-hot vector and append it to both the Encoder and Decoder inputs (see the short check after this list).
- Use a fairly large mini-batch size of 256.[^1]
- Build both the Encoder and the Decoder as simple MLPs.
- Set the dimension of the latent variable to 16.
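As a quick check of the label handling (an illustrative snippet of my own, not part of the original training code), the one-hot conversion and the concatenated inputs behave like this, assuming the definitions above:

```python
# Hypothetical sanity check, run after defining model and to_onehot above
labels = to_onehot(torch.tensor([3, 7]))      # shape (2, 10): one-hot rows for digits 3 and 7
x = torch.rand(2, 28 * 28, device=DEVICE)     # two dummy flattened "images"
mean, lnvar = model.encode(x, labels)         # both have shape (2, ZDIM)
z = torch.randn(2, ZDIM, device=DEVICE)
y = model.decode(z, labels)                   # shape (2, 784), pixel values in (0, 1)
```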
VAE has two main applications, dimensionality reduction and data generation; this time we focus on data generation. Let's create new handwritten digit images using the Decoder of the CVAE trained above.
We fix the label given to the Decoder to "5", draw 100 random vectors from a standard normal distribution, and generate the corresponding images.
```python
# Generate data with label '5'
NUM_GENERATION = 100

os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)

model.eval()
for i in range(NUM_GENERATION):
    z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
    label = torch.tensor([5], device=DEVICE)
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()

    # Save image
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
    plt.close(fig)
```
Result:

Some of the generated digits are a little misshapen, but the model can generate a variety of "5" images.
Next, I looked for a boldly written digit among the test images of `torchvision.datasets.MNIST`. The image below is the 49th image in the dataset: a "4" written very thickly. Let's use the Encoder to find the latent variable corresponding to this image.
```python
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

target_image, label = list(test_dataset)[48]
x = target_image.view(1, 28*28).to(DEVICE)

with torch.no_grad():
    mean, _ = model.encode(x, to_onehot(label))
z = mean
print(f'z = {z.cpu().detach().numpy().squeeze()}')
```
Result:

```
z = [ 0.7933388   2.4768877   0.49229255 -0.09540698 -1.7999544   0.03376897
  0.01600834  1.3863252   0.14656337 -0.14543885  0.04157912  0.13938689
 -0.2016176   0.5204378  -0.08096244  1.0930295 ]
```
This 16-dimensional vector should hold the information about the image **other than** the label given during training. In other words, it should encode "written very thickly" rather than "shaped like a 4".
So, using this latent variable, let's generate images while changing the label information given to the Decoder.
```python
os.makedirs(f'img/cvae/generation/fat', exist_ok=True)

for label in range(CLASS_SIZE):
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()

    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/fat/img{label}')
    plt.close(fig)
```
Result:

The "2" looks a little questionable, but I was able to generate thickly written digits for each label.
I had known about CVAE for a long time, but this was the first time I implemented it, and I'm glad it worked. Implementing a method matters as much as knowing about it. Some of the generated images didn't look great; that might be improved by using convolution and transposed convolution in the VAE network. Although I skipped it this time, for VAE-family models I think it is also important to analyze which features are mapped to which regions of the low-dimensional latent space. I would like to try that analysis next time.
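For reference, here is a minimal sketch of what a convolutional Encoder/Decoder could look like. The layer sizes are my own hypothetical choices and were not tried in this article; the label is still concatenated, here as extra input channels for the Encoder and as extra vector elements for the Decoder.

```python
# Hypothetical convolutional CVAE sketch (not used in this article).
# Unlike the MLP version above, the Encoder expects x with shape (B, 1, 28, 28).
class ConvCVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._encoder = nn.Sequential(
            nn.Conv2d(1 + CLASS_SIZE, 32, kernel_size=4, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),              # 14 -> 7
            nn.ReLU(inplace=True),
            nn.Flatten(),
        )
        self._to_mean = nn.Linear(64 * 7 * 7, zdim)
        self._to_lnvar = nn.Linear(64 * 7 * 7, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, 64 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),     # 7 -> 14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),      # 14 -> 28
            nn.Sigmoid(),
        )

    def encode(self, x, labels):
        # Broadcast the one-hot label over the image plane as extra channels
        label_maps = labels[:, :, None, None].expand(-1, -1, 28, 28)
        h = self._encoder(torch.cat([x, label_maps], dim=1))
        return self._to_mean(h), self._to_lnvar(h)

    def decode(self, z, labels):
        return self._decoder(torch.cat([z, labels], dim=1))
```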
[^1]: This is to make it likely that every label appears in each mini-batch, so that the mini-batch of images is mapped by the Encoder onto something close to a standard normal distribution in the latent-variable space.