Deep Learning Java from scratch 6.4 Regularization

Table of contents

6.4.1 Overfitting

// Overfitting experiment: train a deep fully connected network on a small
// slice of MNIST with weight decay disabled, record train/test accuracy once
// per epoch, and plot both curves to overfit.png.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Reduce the training data to the first 300 samples to reproduce overfitting.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));

// Weight decay setting: 0 means weight decay is not used in this run.
double weight_decay_lambda = 0;
// 784 inputs, six hidden layers of 100 units each, 10 outputs.
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01); // presumably the learning rate -- TODO confirm against SGD
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

// Number of mini-batch updates between two accuracy evaluations (300 / 100 = 3 here).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
// Keep training until max_epochs accuracy measurements have been recorded.
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    // Evaluate on the whole training slice and the whole test set once per epoch.
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// Draw the accuracy curves: train in blue, test in red.
// GraphImage arguments are presumably width, height, minX, minY, maxX, maxY -- TODO confirm.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Recognition accuracy in overfitting", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect each epoch's point to the previous one for both curves.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
// The file name must not end with a space so it matches the image referenced by the article.
graph.writeTo(new File(dir, "overfit.png"));

overfit.png

6.4.2 Weight decay

// Weight decay experiment: same network and reduced training data as the
// overfitting run, but with weight_decay_lambda = 0.1. Train/test accuracy is
// recorded once per epoch and plotted to overfit_weight_decay.png.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Reduce the training data to the first 300 samples to reproduce overfitting.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));

// Weight decay setting: set this to 0 to train without weight decay.
double weight_decay_lambda = 0.1;
// 784 inputs, six hidden layers of 100 units each, 10 outputs.
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01); // presumably the learning rate -- TODO confirm against SGD
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

// Number of mini-batch updates between two accuracy evaluations (300 / 100 = 3 here).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
// Keep training until max_epochs accuracy measurements have been recorded.
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    // Evaluate on the whole training slice and the whole test set once per epoch.
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// Draw the accuracy curves: train in blue, test in red.
// GraphImage arguments are presumably width, height, minX, minY, maxX, maxY -- TODO confirm.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Recognition accuracy in overfitting using Weight decay", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect each epoch's point to the previous one for both curves.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
// The file name must not end with a space so it matches the image referenced by the article.
graph.writeTo(new File(dir, "overfit_weight_decay.png"));

overfit_weight_decay.png

6.4.3 Dropout


// Dropout experiment: train the extended multi-layer network on the reduced
// MNIST slice with Dropout enabled, using Trainer to run the loop, then plot
// the recorded train/test accuracy to dropout.png.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Reduce the training data to the first 300 samples to reproduce overfitting.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();

// Dropout settings: whether Dropout is used, and its ratio.
boolean use_dropout = true; // set to false to train without Dropout
double dropout_ratio = 0.2;

// 784 inputs, six hidden layers of 100 units each, 10 outputs;
// weight decay and batch normalization are both switched off.
MultiLayerNetExtend network = new MultiLayerNetExtend(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu",
    /*weight_init_std*/"relu",
    /*weight_decay_lambda*/0,
    /*use_dropout*/use_dropout, /*dropout_ratio*/dropout_ratio,
    /*use_batchnorm*/false);
Trainer trainer = new Trainer(network, x_train, t_train, x_test, t_test,
    /*epochs*/301,
    /*mini_batch_size*/100,
    /*optimizer*/() -> new SGD(0.01),
    /*evaluate_sample_num_per_epoch*/0,
    /*verbose*/true);

trainer.train();
List<Double> train_acc_list = trainer.train_acc_list;
List<Double> test_acc_list = trainer.test_acc_list;

// Draw the accuracy curves: train in blue, test in red.
// GraphImage arguments are presumably width, height, minX, minY, maxX, maxY -- TODO confirm.
// NOTE(review): the trainer is configured for 301 epochs but the x range given
// here ends at 200; confirm whether points past 200 are meant to be cut off.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Recognition accuracy in Dropout", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect each recorded point to the previous one for both curves.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
// The file name must not end with a space so it matches the image referenced by the article.
graph.writeTo(new File(dir, "dropout.png"));

dropout.png

Recommended Posts

Deep Learning Java from scratch 6.4 Regularization
Study Deep Learning from scratch in Java.
Deep Learning Java from scratch Chapter 1 Introduction
Deep Learning Java from scratch 6.1 Parameter update
Deep Learning Java from scratch Chapter 2 Perceptron
Deep Learning Java from scratch 6.3 Batch Normalization
Deep Learning from scratch Java Chapter 4 Neural network learning
[Deep Learning from scratch] in Java 3. Neural network
Deep Learning Java from scratch Chapter 3 Neural networks
Deep Learning Java from scratch 6.2 Initial values of weights
Deep Learning Java from scratch Chapter 5 Error back propagation method
Fastest PC setup for deep learning from scratch
Java life starting from scratch
[Deep Learning from scratch] 2. There is no such thing as NumPy in Java.
[Deep Learning from scratch] in Java 1. For the time being, differentiation and partial differentiation
Java scratch scratch
First steps for deep learning in Java
For JAVA learning (2018-03-16-01)
Java learning day 5
java learning day 2
java learning day 1
I tried to implement deep learning in Java
Call Java from JRuby
Build VS Code + WSL + Java + Gradle environment from scratch
Changes from Java 8 to Java 11
Sum from Java_1 to 100
Java learning 2 (learning calculation method)
java learning (conditional expression)
Java learning memo (method)
Eval Java source from Java
Java Learning (1)-Hello World
JAVA learning history interface
Access API.AI from Java
Java learning memo (basic)
From Java to Ruby !!
[Note] Create a java environment from scratch with docker
Java learning memo (interface)
Java learning memo (inheritance)
Quick learning Java "Introduction?" Part 3 Talking away from programming
Object-oriented child !? I tried Deep Learning in Java (trial edition)
Learning Java framework # 1 (Mac version)
Java basic learning content 7 (exception)
Migration from Cobol to JAVA
JAVA learning history interface inheritance
Java learning memo (data type)
Java starting from beginner, override
Creating ElasticSearch index from Java
Java basic learning content 5 (modifier)
New features from Java7 to Java8
Books used for learning Java
4th day of java learning
Connect from Java to PostgreSQL
Java learning memo (logical operator)
Java, instance starting from beginner
Java learning memo (abstract class)
Java starting from beginner, inheritance
Let's touch on Deep Java Library (DJL), a library that can handle Deep Learning in Java released from AWS.
Using Docker from Java Gradle
From Ineffective Java to Effective Java
JavaScript as seen from Java
Java Basic Learning Content 8 (Java API)