Is Cutmix valid for table data as well?

Introduction

Supervised learning usually requires a sufficient amount of labeled data to achieve high accuracy, but manual annotation takes a great deal of time and effort. One way to address this is data augmentation, which artificially increases the amount of training data.

However, data augmentation is mostly discussed in the context of images, and few methods are available for table data. This article therefore introduces data augmentation methods that can be applied to table data and evaluates their performance experimentally.

Mixup

mixup: Beyond Empirical Risk Minimization

Mixup is a method proposed in 2017 and accepted at ICLR. It generates a new input by mixing two inputs, and mixes their labels in the same proportion.
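Written out, with (x_i, y_i) and (x_j, y_j) two training examples and λ the mixing ratio drawn from a Beta distribution whose parameter α corresponds to the `alpha` argument in the code, Mixup produces:

```latex
\begin{aligned}
\tilde{x} &= \lambda x_i + (1 - \lambda) x_j \\
\tilde{y} &= \lambda y_i + (1 - \lambda) y_j \\
\lambda &\sim \mathrm{Beta}(\alpha, \alpha)
\end{aligned}
```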

from sklearn.utils import check_random_state


def mixup(x, y=None, alpha=0.2, p=1.0, random_state=None):
    # x: inputs of shape (n_samples, ...); y: targets (e.g. one-hot or multi-label)
    random_state = check_random_state(random_state)
    n = x.shape[0]

    # Apply mixup to this batch with probability p
    if random_state.rand() < p:
        # Mixing ratio lambda ~ Beta(alpha, alpha)
        lam = random_state.beta(alpha, alpha)
        # Pair every sample with a random partner from the same batch
        shuffle = random_state.permutation(n)

        x = lam * x + (1.0 - lam) * x[shuffle]

        if y is not None:
            y = lam * y + (1.0 - lam) * y[shuffle]

    return x, y
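Because the MoA task used later in this article is multi-label, each label y is a vector of 0/1 flags, and mixing turns two hard label vectors into one soft target. The hand-computed instance below uses made-up values and a fixed λ = 0.7 purely for illustration; in `mixup` the ratio is drawn from Beta(alpha, alpha).

```python
import numpy as np

# Two hypothetical samples: 2 features each, 3 binary labels (multi-label)
x_i, y_i = np.array([1.0, 2.0]), np.array([1.0, 0.0, 1.0])
x_j, y_j = np.array([3.0, 6.0]), np.array([0.0, 0.0, 1.0])

lam = 0.7  # fixed for illustration; mixup samples this from Beta(alpha, alpha)

x_mix = lam * x_i + (1.0 - lam) * x_j  # features: about [1.6, 3.2]
y_mix = lam * y_i + (1.0 - lam) * y_j  # soft labels: about [0.7, 0.0, 1.0]

print(x_mix, y_mix)
```

A label present in both samples stays at 1, a label absent from both stays at 0, and only labels that differ become fractional.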

The paper reports that Mixup also improves performance on audio and tabular data, not only on images.

Cutmix

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

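Cutmix is a method proposed in 2019 and accepted at ICCV. It generates a new input by replacing a rectangular region of one image with the corresponding region of another image, and mixes the labels in proportion to the area of that region.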
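Using the same notation as for Mixup, let M be a binary mask that is 0 inside the cut-out box and 1 elsewhere, and ⊙ element-wise multiplication. Following the CutMix paper, a box of width r_w and height r_h on a W × H image is sized so that it covers a fraction 1 − λ of the image, its position is random, and λ ~ Beta(α, α) (the paper uses α = 1, i.e. a uniform distribution, which matches the `alpha=1.0` default in the code):

```latex
\begin{aligned}
\tilde{x} &= M \odot x_i + (1 - M) \odot x_j \\
\tilde{y} &= \lambda y_i + (1 - \lambda) y_j \\
r_w &= W \sqrt{1 - \lambda}, \quad r_h = H \sqrt{1 - \lambda}
\end{aligned}
```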

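import numpy as np
from sklearn.utils import check_random_state


def cutmix(x, y=None, alpha=1.0, p=1.0, random_state=None):
    # x: images of shape (n_samples, height, width, channels)
    random_state = check_random_state(random_state)
    n, h, w, _ = x.shape

    if random_state.rand() < p:
        lam = random_state.beta(alpha, alpha)
        # Box sides chosen so that its area is about (1 - lam) of the image
        cut_h = int(h * np.sqrt(1.0 - lam))
        cut_w = int(w * np.sqrt(1.0 - lam))
        # Top-left corner drawn so the box always fits inside the image
        top = random_state.randint(h - cut_h + 1)
        left = random_state.randint(w - cut_w + 1)
        rows = slice(top, top + cut_h)
        cols = slice(left, left + cut_w)
        shuffle = random_state.permutation(n)

        # Copy so the caller's array is not modified in place
        x = x.copy()
        x[:, rows, cols] = x[shuffle, rows, cols]

        if y is not None:
            # Recompute lambda from the actual box area (int rounding changes it),
            # as the CutMix paper does
            lam = 1.0 - (cut_h * cut_w) / (h * w)
            y = lam * y + (1.0 - lam) * y[shuffle]

    return x, y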
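The two functions use different defaults for `alpha`: 0.2 for mixup and 1.0 for cutmix. Since λ ~ Beta(α, α), α controls how strongly samples are mixed. The sketch below uses a hypothetical helper, `fraction_near_endpoints` (not part of either paper), to estimate how often λ falls within 0.1 of 0 or 1, that is, how often the mixed sample stays close to one of the two originals.

```python
import numpy as np


def fraction_near_endpoints(alpha, n_samples=100_000, margin=0.1, seed=0):
    """Share of lambda ~ Beta(alpha, alpha) draws within `margin` of 0 or 1."""
    rng = np.random.RandomState(seed)
    lam = rng.beta(alpha, alpha, size=n_samples)
    return np.mean((lam < margin) | (lam > 1.0 - margin))


for alpha in (0.2, 1.0):
    print(alpha, fraction_near_endpoints(alpha))
```

With α = 1.0 the distribution is uniform, so about 20% of draws land in that range; with α = 0.2 the density piles up near 0 and 1 and the share is much larger (roughly two thirds), so most mixed samples are only lightly perturbed.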

The paper reports results of Cutmix on images only. What happens if we apply it to table data?

In table data, the order of the features (age, nationality, etc.) has no meaning, so cutting out a contiguous block of columns makes little sense. Instead, I randomly select which features to replace with those of the other input.

from sklearn.utils import check_random_state


def cutmix_for_tabular(x, y=None, alpha=1.0, p=1.0, random_state=None):
    # x: table of shape (n_samples, n_features)
    random_state = check_random_state(random_state)
    n, d = x.shape

    if random_state.rand() < p:
        lam = random_state.beta(alpha, alpha)
        # Boolean mask over features: True (replace) with probability 1 - lam.
        # Feature order carries no meaning, so columns are chosen independently
        mask = random_state.choice([False, True], size=d, p=[lam, 1.0 - lam])
        shuffle = random_state.permutation(n)

        # Copy so the caller's array is not modified in place
        x = x.copy()
        # Take the masked columns from each sample's partner row
        x[:, mask] = x[shuffle][:, mask]

        if y is not None:
            y = lam * y + (1.0 - lam) * y[shuffle]

    return x, y
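To make the column-wise replacement concrete, here is a tiny worked example on two made-up rows with a fixed, hand-picked mask. The masked features do not need to be adjacent, which is exactly what distinguishes this variant from the rectangular box used for images.

```python
import numpy as np

# Two hypothetical rows of a table with 5 features (values are made up)
row_a = np.array([25.0, 1.0, 0.3, 7.0, 2.0])
row_b = np.array([60.0, 0.0, 0.9, 1.0, 5.0])

# Features to take from row_b; in tabular data they need not be adjacent
mask = np.array([False, True, False, False, True])

# Keep row_a where mask is False, take row_b where mask is True
mixed = np.where(mask, row_b, row_a)  # [25.0, 0.0, 0.3, 7.0, 5.0]

# Share of features that still come from row_a (0.6 here); cutmix_for_tabular
# weights the labels by the sampled mixing ratio, which matches this share
# only in expectation
kept_fraction = 1.0 - mask.mean()

print(mixed, kept_fraction)
```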

Experiment

For the experiment, I use the following competition data. The task is multi-label classification: predicting the mechanisms of action of compounds from gene expression patterns.

Mechanisms of Action (MoA) Prediction | Kaggle

Check the code below for the details of the experiment.

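The log loss (lower is better) for each method is shown below. Local is the score on local validation; Public and Private are the leaderboard scores.

Method     Local     Public    Private
Baseline   0.01604   0.01906   0.01666
Mixup      0.01605   0.01905   0.01668
Cutmix     0.01604   0.01901   0.01663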

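Cutmix improved both the Public (0.01906 → 0.01901) and Private (0.01666 → 0.01663) scores, while its Local score matched the baseline. Mixup, by contrast, did not improve the Local or Private score.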
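The scores in the table are multi-label log loss. A minimal sketch of such a metric follows, assuming the score is binary log loss computed per label column and then averaged over the columns; the function name `mean_columnwise_log_loss` and the clipping constant `eps` are illustrative choices, not the competition's official code.

```python
import numpy as np


def mean_columnwise_log_loss(y_true, y_pred, eps=1e-15):
    """Binary log loss per label column, averaged over columns."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions so that log(0) never occurs
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    per_entry = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    # Mean over samples for each label, then mean over labels
    return per_entry.mean(axis=0).mean()


y_true = np.array([[1, 0], [0, 1]])
y_pred = np.array([[0.9, 0.1], [0.2, 0.8]])
print(mean_columnwise_log_loss(y_true, y_pred))
```

The same formula also accepts soft targets between 0 and 1, which is why the mixed labels produced by Mixup and Cutmix can be used directly with a binary cross-entropy training loss.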

Conclusion

In this experiment, Cutmix was also effective for table data.

I have also published my 35th-place solution to the above competition, which uses Cutmix, so please take a look if you are interested.

Mechanisms of Action (MoA) Prediction | Kaggle

