Deep Kernel Learning combines deep learning with Gaussian processes and is one approach to Bayesian deep learning. The features output by a deep neural network (DNN) are used as the input to the kernel of a Gaussian process, which yields a deep kernel. The formula is as follows.
k_{deep}(x,x') = k(f(x),f(x'))
A Gaussian process corresponds to a neural network with infinitely many hidden units, so Deep Kernel Learning can be viewed as attaching such an infinitely wide layer to the end of a DNN. As I found in the previous article (https://qiita.com/takeajioka/items/f24d58d2b13017ab2b18), optimizing the kernel hyperparameters is important when fitting a Gaussian process. Deep Kernel Learning optimizes the DNN parameters and the kernel hyperparameters at the same time.
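To make the formula concrete, here is a minimal sketch in plain Python. It assumes an RBF base kernel and uses a toy one-layer feature map as a stand-in for the DNN; the names `feature_map`, `rbf`, `deep_kernel` and the weight values are made up for illustration and are not part of Pyro.

```python
import math

def feature_map(x, weights):
    """Toy stand-in for the DNN f: one linear layer followed by tanh."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]

def rbf(u, v, lengthscale=1.0):
    """Base kernel k(u, v) = exp(-||u - v||^2 / (2 * lengthscale^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2.0 * lengthscale ** 2))

def deep_kernel(x1, x2, weights, lengthscale=1.0):
    """k_deep(x, x') = k(f(x), f(x'))."""
    return rbf(feature_map(x1, weights), feature_map(x2, weights), lengthscale)

# Hypothetical weights mapping 3-dimensional inputs to 2 features.
weights = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5]]
x_a = [1.0, 2.0, 3.0]
x_b = [1.0, 2.0, 3.5]

k_aa = deep_kernel(x_a, x_a, weights)  # similarity of a point with itself
k_ab = deep_kernel(x_a, x_b, weights)  # similarity measured in feature space
k_ba = deep_kernel(x_b, x_a, weights)
print(k_aa, k_ab, k_ba)
```

In this sketch the weights are fixed. In Deep Kernel Learning they would be trained together with the kernel's lengthscale.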
Please refer to the following papers for details.
[1] Deep Kernel Learning, 2015, Andrew G. Wilson et al., https://arxiv.org/abs/1511.02222
[2] Stochastic Variational Deep Kernel Learning, 2016, Andrew G. Wilson et al., https://arxiv.org/abs/1611.00336
In Pyro, you can easily create a deep kernel with the gp.kernels.Warping class. The official Pyro tutorial includes Deep Kernel Learning code, so let's learn by following it.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import pyro
import pyro.contrib.gp as gp
import pyro.infer as infer
MNIST is a large dataset, so we train on it in mini-batches. First, set up the datasets and data loaders.
batch_size = 100
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
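As a small aside, the Normalize((0.5, ), (0.5, )) step above computes (x - mean) / std per pixel. Since ToTensor produces values in [0, 1], the result lies in [-1, 1]. The following plain-Python sketch only illustrates that arithmetic; `normalize` is a hypothetical helper, not a torchvision function.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize: (x - mean) / std."""
    return (pixel - mean) / std

# Pixel values after ToTensor lie in [0, 1].
black = normalize(0.0)  # darkest pixel
gray = normalize(0.5)   # mid-gray pixel
white = normalize(1.0)  # brightest pixel
print(black, gray, white)
```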
First, prepare an ordinary DNN model.
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional layers followed by two fully connected layers
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)  # flatten to (batch, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)      # 10-dimensional features fed to the kernel
        return x
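The value 320 in fc1 and in view(-1, 320) follows from the layer shapes. The sketch below retraces that arithmetic for a 28x28 MNIST image, assuming the defaults used in the model (stride 1 and no padding for the convolutions, and a pooling stride equal to the pooling kernel size). The helper names `conv_out` and `pool_out` are made up for illustration.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output width/height of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

def pool_out(size, kernel_size):
    """Output size of max pooling whose stride equals its kernel size."""
    return size // kernel_size

size = 28                              # MNIST images are 28x28
size = pool_out(conv_out(size, 5), 2)  # conv1 + pool: 28 -> 24 -> 12
after_first_block = size
size = pool_out(conv_out(size, 5), 2)  # conv2 + pool: 12 -> 8 -> 4
flat_features = 20 * size * size       # 20 channels from conv2
print(flat_features)  # 320
```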
Combine it with an RBF kernel through Warping to create a deep kernel. The kernel's input_dim of 10 matches the 10-dimensional output of the CNN.
rbf = gp.kernels.RBF(input_dim=10, lengthscale=torch.ones(10))
deep_kernel = gp.kernels.Warping(rbf, iwarping_fn=CNN())
A sparse approximation is used to reduce the computational cost of the Gaussian process. Sparse approximations rely on inducing points, and here we initialize them with one mini-batch of training data.
Xu, _ = next(iter(trainloader))
likelihood = gp.likelihoods.MultiClass(num_classes=10)
gpmodule = gp.models.VariationalSparseGP(X=None, y=None, kernel=deep_kernel, Xu=Xu, likelihood=likelihood, latent_shape=torch.Size([10]), num_data=60000)
optimizer = torch.optim.Adam(gpmodule.parameters(), lr=0.01)
elbo = infer.TraceMeanField_ELBO()
loss_fn = elbo.differentiable_loss
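To give a rough sense of why inducing points help, the sketch below compares the standard cost orders: O(N^3) for an exact Gaussian process and O(N M^2) for a sparse approximation with M inducing points. These are operation-count orders, not measured runtimes. The last line shows the factor N / batch_size by which a mini-batch estimate of the likelihood term is usually rescaled. That this is how num_data=60000 is used here is my assumption, not something confirmed by the code above.

```python
num_data = 60000     # N: number of MNIST training images
num_inducing = 100   # M: one mini-batch is used as the inducing points Xu
batch_size = 100

# Order-of-magnitude operation counts (constants ignored).
exact_cost = num_data ** 3                  # exact GP: O(N^3)
sparse_cost = num_data * num_inducing ** 2  # sparse GP: O(N * M^2)
ratio = exact_cost // sparse_cost

# Assumed rescaling of the mini-batch likelihood term: N / batch_size.
likelihood_scale = num_data / batch_size

print(exact_cost, sparse_cost, ratio, likelihood_scale)
```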
Define the mini-batch training function and the test function.
def train(train_loader, gpmodule, optimizer, loss_fn, epoch):
    total_loss = 0
    for data, target in train_loader:
        gpmodule.set_data(data, target)
        optimizer.zero_grad()
        loss = loss_fn(gpmodule.model, gpmodule.guide)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float so the computation graph is not kept alive
        total_loss += loss.item()
    return total_loss / len(train_loader)

def test(test_loader, gpmodule):
    correct = 0
    for data, target in test_loader:
        # Mean and variance of the latent function for each class
        f_loc, f_var = gpmodule(data)
        pred = gpmodule.likelihood(f_loc, f_var)
        correct += pred.eq(target).long().sum().item()
    return 100. * correct / len(test_loader.dataset)
Run the training.
import time
losses = []
accuracy = []
epochs = 10
for epoch in range(epochs):
    start_time = time.time()
    loss = train(trainloader, gpmodule, optimizer, loss_fn, epoch)
    losses.append(loss)
    with torch.no_grad():
        acc = test(testloader, gpmodule)
    accuracy.append(acc)
    print("Amount of time spent for epoch {}: {}s\n".format(epoch+1, int(time.time() - start_time)))
    print("loss:{:.2f}, accuracy:{}".format(losses[-1],accuracy[-1]))
Training took about 30 seconds per epoch. The final accuracy was 96.23%. (It seems that the official tutorial reaches up to 99.41%.) Next, display the learning curve.
import matplotlib.pyplot as plt
plt.subplot(2,1,1)
plt.plot(losses)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.subplot(2,1,2)
plt.plot(accuracy)
plt.xlabel("epoch")
plt.ylabel("accuracy")
Let's look at the test image and the predicted output side by side.
import os
os.makedirs('image', exist_ok=True)  # make sure the output directory exists
data, target = next(iter(testloader))
f_loc, f_var = gpmodule(data)
pred = gpmodule.likelihood(f_loc, f_var)
for i in range(len(data)):
    # Left: the test image
    plt.subplot(1,2,1)
    plt.imshow(data[i].reshape(28, 28))
    # Right: per-class mean with the variance as the error bar
    plt.subplot(1,2,2)
    plt.bar(range(10), f_loc[:,i].detach(), yerr= f_var[:,i].detach())
    ax = plt.gca()
    ax.set_xticks(range(10))
    plt.xlabel("class")
    plt.savefig('image/figure'+ str(i) +'.png')
    plt.clf()
The blue bars show the mean and the error bars show the variance. Each of the 10 classes has an output, and the output for the correct class is high.
Let's also look at images that are difficult to classify. For these, the output is high for multiple classes. Considering the error bars, the difference between those classes does not appear to be significant.
Training worked much like ordinary deep learning. An advantage that ordinary deep learning does not have is that the model outputs both a mean and a variance. There is also the deep Gaussian process (DGP), which stacks Gaussian processes, so I would like to study that as well.
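One simple way to make the "no significant difference" judgment less visual is to check whether the intervals of the two highest classes overlap. The sketch below is only a rough heuristic under my own assumptions: it uses mean plus or minus one standard deviation (the square root of the variance), and the numbers are hypothetical, not outputs of the trained model. `top_two_overlap` is a made-up helper.

```python
import math

def top_two_overlap(means, variances, num_std=1.0):
    """Return the two highest-mean classes and whether their intervals overlap."""
    order = sorted(range(len(means)), key=lambda c: means[c], reverse=True)
    best, second = order[0], order[1]
    lower_best = means[best] - num_std * math.sqrt(variances[best])
    upper_second = means[second] + num_std * math.sqrt(variances[second])
    return best, second, upper_second >= lower_best

# Hypothetical per-class outputs for a clearly written digit.
clear_means = [-1.0, -0.5, 0.2, 3.0, -0.8, 0.1, -1.2, -0.3, 0.0, -0.6]
clear_vars = [0.04] * 10

# Hypothetical per-class outputs for an ambiguous digit (3 versus 8).
ambiguous_means = [-1.0, -0.5, 0.2, 1.1, -0.8, 0.1, -1.2, -0.3, 1.0, -0.6]
ambiguous_vars = [0.25] * 10

clear_result = top_two_overlap(clear_means, clear_vars)
ambiguous_result = top_two_overlap(ambiguous_means, ambiguous_vars)
print(clear_result, ambiguous_result)
```

With the real model, the per-class values for image i would come from f_loc[:, i] and f_var[:, i].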