Implemented continuous learning using Mahalanobis distance in feature space

Continuous learning

Continuous learning means that a model keeps learning from new data that arrives sequentially over a long period of time. For details, refer to the slide deck that summarizes continuous learning.

This time, we implement a paper that performs continuous learning using an out-of-distribution detection method.

Paper title: A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks

Outline of the paper: Out-of-distribution data is treated as data of a new class, a Gaussian is fitted to each class in the feature space of the deep model, and test data is classified by its Mahalanobis distance to the mean vectors of the old and new classes. A detailed explanation of the paper is given in the second part of the slide deck.
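The classification rule in the outline can be illustrated on toy data. This is a minimal sketch, not the paper's code: `mahalanobis_classify` is a hypothetical helper, and the two-dimensional means and shared covariance are made-up values standing in for the deep features.

```python
import numpy as np

def mahalanobis_classify(z, means, cov):
    """Assign each row of z to the class whose mean is nearest in Mahalanobis distance."""
    cov_inv = np.linalg.inv(cov)
    diff = z[:, None, :] - means[None, :, :]                   # shape (n, C, d)
    dist2 = np.einsum("ncd,de,nce->nc", diff, cov_inv, diff)   # squared distances, shape (n, C)
    return dist2.argmin(axis=1), dist2

# Two toy classes that share one covariance matrix (assumed values)
means = np.array([[0.0, 0.0], [4.0, 0.0]])
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
z = np.array([[0.2, -0.1], [3.9, 0.3]])

labels, dist2 = mahalanobis_classify(z, means, cov)
print(labels)
```

Each sample is assigned to the class with the smallest squared distance, so a point near the first mean gets label 0 and a point near the second mean gets label 1.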

The strength of this paper is that **continuous learning is possible without reducing the accuracy of the classes learned so far**, and that **there is no need to retrain the DNN parameters on the data of the newly added classes**. Therefore, even as the number of classes increases, learning remains very fast.

Implementation problem setting

- Classes 0 to 4 of CIFAR10 are given first as training data, and the training data for the additional classes 5 to 9 is then given one class at a time.
- Only the data of the newly given class may be used for learning; retraining with the data of the initially given classes is not allowed.

Implementation explanation

First preparation

import os 

import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms as T
from torchvision.datasets import CIFAR10

# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# Place model and utils from the above repository in the same directory as this script
from model import EfficientNet 
from tqdm import tqdm
plt.style.use("ggplot")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

I prepared a function that returns a DataLoader restricted to the given CIFAR10 classes, so that classes can be added one at a time.

def return_data_loader(classes, train=True, batch_size=128):
    transform = []
    transform.append(T.Resize((64, 64))) #Need to resize to use efmodel
    transform.append(T.ToTensor())
    transform = T.Compose(transform)

    dataset = CIFAR10("./data", train=train, download=True, transform=transform)
    targets = np.array(dataset.targets)
    mask = np.array([t in classes for t in targets])
    dataset.data = dataset.data[mask]
    dataset.targets = targets[mask]
    
    data_loader = DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=train)
    return data_loader
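The class filter used inside `return_data_loader` can be checked in isolation, without downloading CIFAR10. This is a minimal sketch under stated assumptions: `filter_by_class` is a hypothetical helper, the small arrays are stand-ins for `dataset.data` and `dataset.targets`, and `np.isin` is used as a vectorized equivalent of the list comprehension above.

```python
import numpy as np

def filter_by_class(data, targets, classes):
    """Keep only the samples whose label is in `classes`."""
    targets = np.asarray(targets)
    mask = np.isin(targets, list(classes))   # boolean mask, one entry per sample
    return data[mask], targets[mask]

# Synthetic stand-ins for dataset.data and dataset.targets (assumed values)
data = np.arange(6 * 2).reshape(6, 2)
targets = [0, 5, 1, 7, 5, 0]

sub_data, sub_targets = filter_by_class(data, targets, classes=[0, 5])
print(sub_targets)
```

Passing a single-element list such as `classes=[c]` gives the per-class subsets that are needed later for the per-class statistics.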

EfficientNet-B0, the smallest EfficientNet model, is used as the classifier.

First, as shown by the blue arrow in the figure, a 5-class discriminative model is trained in the usual way. Next, as shown by the red arrow, the features just before the final layer are approximated by a Gaussian distribution for each class.

(Figure: screenshot 2020-10-05 9.08.56.png)

NCLASS = 5 #Initial class
classes = np.arange(NCLASS)
model = 'efficientnet-b0'
weight_dir = "."

clf = EfficientNet.from_name(model)
clf._fc = torch.nn.Linear(clf._fc.in_features, NCLASS)
clf = clf.to(device)
clf.train()
train_loader = return_data_loader(classes=classes, train=True)

lr = 0.001
epoch_num = 50
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

for epoch in tqdm(range(epoch_num)):
    train_loss = 0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        logit = clf(x)
        loss = criterion(logit, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader)  # average of the per-batch mean losses
torch.save(clf.state_dict(), os.path.join(weight_dir, 'weight.pth'))

test_loader = return_data_loader(range(10), train=False)
clf.load_state_dict(torch.load(os.path.join(weight_dir, 'weight.pth')))
clf.eval()

pred = []
true = []
with torch.no_grad():
    for x, y in test_loader:
        pred.extend(clf(x.to(device)).max(1)[1].cpu().numpy())
        true.extend(y.numpy())
print(accuracy_score(true, pred))
print(confusion_matrix(true, pred))

First, here are the accuracy and the confusion matrix when the discriminative model is trained in the usual way. Since only classes 0 to 4 are used for training, the model naturally cannot predict classes 5 to 9, and samples from those classes are forced into one of classes 0 to 4.

0.4279 # accuracy
[[877  27  47  35  14   0   0   0   0   0]
 [ 14 972   3   8   3   0   0   0   0   0]
 [ 51   7 785  81  76   0   0   0   0   0]
 [ 20  18 107 780  75   0   0   0   0   0]
 [ 13   2  58  62 865   0   0   0   0   0]
 [ 13  12 226 640 109   0   0   0   0   0]
 [ 26  55 232 477 210   0   0   0   0   0]
 [ 47  21 188 230 514   0   0   0   0   0]
 [604 214  53  95  34   0   0   0   0   0]
 [160 705  43  78  14   0   0   0   0   0]]

Next, to implement the red arrow part of the figure above, compute the mean and covariance of the features.

(Figure: screenshot 2020-10-05 9.29.18.png)
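The statistics that the code below computes can be written out as follows. This is read off from the code rather than taken from the figure, so treat the notation as an assumption: $f(x)$ is the pooled feature vector, $N_c$ the number of training samples in class $c$, and $C$ the number of classes seen so far.

```latex
\mu_c = \frac{1}{N_c} \sum_{i : y_i = c} f(x_i)

\Sigma = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{N_c} \sum_{i : y_i = c}
         \left(f(x_i) - \mu_c\right)\left(f(x_i) - \mu_c\right)^\top
```

Each class has its own mean, but a single covariance matrix is shared by all classes. When every class has the same number of samples, as in the CIFAR10 training set, this average of per-class covariances equals the covariance pooled over all samples.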

def ext_feature(x):
    z = clf.extract_features(x)
    z = clf._avg_pooling(z)
    z = z.flatten(start_dim=1)
    return z.detach().cpu().numpy()

train_loaders = [return_data_loader(classes=[c], train=True) for c in range(10)]

z_mean = []
z_var = 0
target_count = []

for c in tqdm(range(NCLASS)): #Existing class
    N = len(train_loaders[c].dataset) #Holding the number of each class
    target_count.append(N)
    
    with torch.no_grad():
        #Average calculation
        new_z_mean = 0
        for x, _ in train_loaders[c]:
            x = x.to(device)
            new_z_mean += ext_feature(x).sum(0) / N
        z_mean.append(new_z_mean)

        # Covariance calculation (accumulated over classes, averaged afterwards)
        for x, _ in train_loaders[c]:
            x = x.to(device)
            diff = ext_feature(x) - new_z_mean  # extract the features once per batch
            z_var += diff.T.dot(diff) / N

C = len(z_mean)
z_var /=  C
z_mean = np.array(z_mean)
target_count = np.array(target_count)

Once the mean and covariance are obtained, classification is possible without the final fully connected layer by using the Mahalanobis distance. The implementation uses Bayes' theorem:

(Figure: screenshot 2020-10-05 9.32.07.png)

where $\beta_c$ is the number of training samples in class $c$.
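The posterior that the code evaluates can be reconstructed from the variables `A`, `B` and `beta` that follow. This is a reconstruction from the code, not a copy of the figure; $f(x)$, $\mu_c$ and $\Sigma$ denote the feature vector, the class mean and the shared covariance.

```latex
P(y = c \mid x) =
  \frac{\beta_c \exp\!\left(\mu_c^\top \Sigma^{-1} f(x) - \tfrac{1}{2}\mu_c^\top \Sigma^{-1} \mu_c\right)}
       {\sum_{c'} \beta_{c'} \exp\!\left(\mu_{c'}^\top \Sigma^{-1} f(x) - \tfrac{1}{2}\mu_{c'}^\top \Sigma^{-1} \mu_{c'}\right)}

\hat{y}(x) = \arg\max_c \left[\mu_c^\top \Sigma^{-1} f(x)
             - \tfrac{1}{2}\mu_c^\top \Sigma^{-1} \mu_c + \log \beta_c\right]
```

The denominator is the same for every class, so the prediction only needs the three terms inside the argmax: `A` applied to the features gives the first term, `B` is the second, and `beta` is the third.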

# Add a small ridge term so that the matrix inverse stays numerically stable
z_var_inv = np.linalg.inv(z_var + np.eye(z_mean.shape[1])*1e-6)
A = z_mean.dot(z_var_inv) # First term inside the exp of the numerator: mu_c^T Sigma^-1 (applied to f(x) below)
B = (A*z_mean).sum(1) * 0.5 # Second term: 0.5 * mu_c^T Sigma^-1 mu_c
beta = np.log(target_count) # Third term: log of the per-class sample count

accs = []
pred = []
true = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        pred.extend((A.dot(ext_feature(x).T) - B[:, None] + beta[:, None]).argmax(0))
        true.extend(y.numpy())
acc = accuracy_score(true, pred)        
print(acc)
accs.append(acc)
confusion_matrix(true, pred)

The results below show that almost the same accuracy is achieved without the fully connected layer (0.4273 versus 0.4279).

0.4273 # accuracy
array([[899,  17,  43,  29,  12,   0,   0,   0,   0,   0],
       [ 25, 958,   6,   9,   2,   0,   0,   0,   0,   0],
       [ 55,   6, 785,  86,  68,   0,   0,   0,   0,   0],
       [ 29,  15, 109, 773,  74,   0,   0,   0,   0,   0],
       [ 23,   2,  55,  62, 858,   0,   0,   0,   0,   0],
       [ 22,   6, 227, 641, 104,   0,   0,   0,   0,   0],
       [ 34,  39, 256, 468, 203,   0,   0,   0,   0,   0],
       [ 71,  16, 199, 214, 500,   0,   0,   0,   0,   0],
       [653, 182,  53,  84,  28,   0,   0,   0,   0,   0],
       [221, 645,  42,  78,  14,   0,   0,   0,   0,   0]])
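The linear score used above, `A.dot(z) - B`, ranks classes exactly as the Mahalanobis distance does, because the two differ only by a term that does not depend on the class. The following sketch checks this numerically on random data. The dimensions and random values are assumptions chosen for illustration, and the class-count term `beta` is left out (equal priors).

```python
import numpy as np

rng = np.random.default_rng(0)
d, C, n = 4, 3, 5                        # feature dim, classes, samples (assumed)
means = rng.normal(size=(C, d))
M = rng.normal(size=(d, d))
cov = M @ M.T + np.eye(d)                # symmetric positive definite covariance
cov_inv = np.linalg.inv(cov)
z = rng.normal(size=(n, d))

# Linear form, written like the article's A and B
A = means @ cov_inv                      # (C, d)
B = 0.5 * (A * means).sum(axis=1)        # (C,)
score = A @ z.T - B[:, None]             # (C, n)

# Squared Mahalanobis distance of every sample to every class mean
diff = z[None, :, :] - means[:, None, :]                   # (C, n, d)
dist2 = np.einsum("cnd,de,cne->cn", diff, cov_inv, diff)   # (C, n)

# score + 0.5 * dist2 equals 0.5 * z^T cov_inv z, which is the same for all classes
offset = score + 0.5 * dist2
same_argmax = np.array_equal(score.argmax(axis=0), dist2.argmin(axis=0))
print(same_argmax)
```

This is why the fully connected layer can be replaced: maximizing the linear score is the same as choosing the nearest class mean in Mahalanobis distance.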

Implementation of continuous learning

The goal is to classify test data from all classes using only the mean and covariance of the features of the new data, without retraining the model parameters.

The outline of the algorithm is as follows.

(Figure: screenshot 2020-10-05 9.39.07.png)
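The update step can be sketched on synthetic features before looking at the full loop. This is a minimal sketch under stated assumptions: `add_class` is a hypothetical helper, and the Gaussian features are random stand-ins for the EfficientNet features. It mirrors the update `z_var*C/(C+1) + new_z_var/(C+1)` used in the code that follows.

```python
import numpy as np

def add_class(means, shared_cov, new_feats):
    """Append the mean of a new class and fold its covariance into the shared one."""
    C = len(means)                                   # number of classes seen so far
    new_mean = new_feats.mean(axis=0)
    centered = new_feats - new_mean
    new_cov = centered.T @ centered / len(new_feats)
    means = np.concatenate([means, new_mean[None, :]])
    shared_cov = shared_cov * C / (C + 1) + new_cov / (C + 1)
    return means, shared_cov

rng = np.random.default_rng(0)
d = 3
feats = [rng.normal(loc=c, size=(200, d)) for c in range(3)]   # three toy classes
covs = [np.cov(f, rowvar=False, bias=True) for f in feats]

# Start from the first two classes, then add the third incrementally
means = np.stack([f.mean(axis=0) for f in feats[:2]])
shared = (covs[0] + covs[1]) / 2
means, shared = add_class(means, shared, feats[2])
```

After the update, the shared covariance equals the plain average of the three per-class covariances, so old-class data never has to be revisited; only the running mean list and one covariance matrix are stored.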

for c in tqdm(range(NCLASS, 10)): #New class
    N = len(train_loaders[c].dataset)

    with torch.no_grad():
        #Average calculation
        new_z_mean = 0        
        for x, _ in train_loaders[c]:
            x = x.to(device)
            new_z_mean += ext_feature(x).sum(0) / N 


        # Covariance calculation for the new class
        new_z_var = 0
        for x, _ in train_loaders[c]:
            x = x.to(device)
            diff = ext_feature(x) - new_z_mean  # extract the features once per batch
            new_z_var += diff.T.dot(diff) / N

    #Average and variance updates
    C = len(target_count)
    z_mean = np.concatenate([z_mean, new_z_mean[None, :]])
    z_var = z_var*C/(C+1) + new_z_var/(C+1)
    target_count = np.append(target_count, N)

    z_var_inv = np.linalg.inv(z_var + np.eye(z_mean.shape[1])*1e-6)
    A = z_mean.dot(z_var_inv) 
    B = (A*z_mean).sum(1) * 0.5
    beta = np.log(target_count)
    pred = []
    true = []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            pred.extend((A.dot(ext_feature(x).T) - B[:, None] + beta[:, None]).argmax(0))
            true.extend(y.numpy())
    acc = accuracy_score(true, pred)
    accs.append(acc)
    print(acc)

The final result, after all 10 classes have been added, is:

0.4974 # accuracy
array([[635,   1,  18,   4,   2,  14,   9,  36, 260,  21],
       [  1, 761,   0,   1,   0,   0,   8,   3,  21, 205],
       [ 20,   0, 581,  12,   8,  97, 105, 135,  35,   7],
       [  5,   0,  22, 450,  13, 256, 147,  60,  29,  18],
       [  2,   1,  16,  10, 555,  30,  63, 302,  20,   1],
       [  1,   0,  57, 288,  22, 325, 173, 106,  22,   6],
       [  5,   0,  49, 139,  36, 182, 350, 161,  35,  43],
       [  5,   2,  34,  50, 131, 104, 158, 446,  58,  12],
       [226,  26,  13,  11,   3,  22,  58,  41, 430, 170],
       [ 17, 250,   6,   5,   0,   8,  69,  16, 188, 441]])
plt.title("accuracy")
plt.plot(accs)
plt.xlabel("number of classes added")
plt.ylabel("accuracy")
plt.show()

The x-axis is the number of classes added. The accuracy once training data for all 10 classes has been given (0.4974) is about 0.07 higher than when only 5 classes are given (0.4273).

(Figure: download-6.png)
