Deep Learning Java from scratch 6.3 Batch Normalization

6.3 Evaluation of Batch Normalization

The Layer interface was extended so that a train_flg can be passed to layers that need it: a BatchNormLayer sub-interface was created.

Layer.java


// Common interface for layers whose forward pass does not depend on training vs. inference.
public interface Layer {

    INDArray forward(INDArray x);
    INDArray backward(INDArray x);

}

BatchNormLayer.java


public interface BatchNormLayer extends Layer {

    // Layers that need train_flg must not be called through the one-argument forward.
    public default INDArray forward(INDArray x) {
        throw new IllegalAccessError();
    }

    // Forward pass that is told whether the network is training or doing inference.
    INDArray forward(INDArray x, boolean train_flg);

}
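
The reason for splitting the interface this way is that the network can check, layer by layer, whether the training flag has to be passed. As a minimal sketch (the method name and signature below are illustrative only, not the article's actual network code), a forward pass over a list of layers might dispatch like this:


// Sketch only: forward propagation over a list of layers, passing train_flg to the
// layers that need it (Dropout, BatchNormalization) and the plain forward to the rest.
static INDArray predict(List<Layer> layers, INDArray x, boolean train_flg) {
    for (Layer layer : layers) {
        if (layer instanceof BatchNormLayer)
            x = ((BatchNormLayer) layer).forward(x, train_flg);
        else
            x = layer.forward(x);
    }
    return x;
}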

The rest of the work is to implement this interface in the Dropout and BatchNormalization classes, implement the MultiLayerNetExtend class on top of them, and run the test code further below. A rough sketch of what the BatchNormalization layer's forward pass might look like is shown next, followed by the test code.
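
The BatchNormalization class itself is not reproduced in this article, so the following is a minimal sketch, assuming ND4J, of how a layer implementing BatchNormLayer could normalize with mini-batch statistics during training and with running statistics at inference time. The class, field, and constant names here (gamma, beta, runningMean, runningVar, momentum) are illustrative assumptions rather than the actual implementation, and backward together with the gamma/beta gradients is omitted.


import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

// Sketch only: forward pass of a Batch Normalization layer implementing BatchNormLayer.
public class BatchNormalization implements BatchNormLayer {

    INDArray gamma, beta;              // learned scale and shift parameters
    INDArray runningMean, runningVar;  // statistics used at inference time
    double momentum = 0.9;

    public BatchNormalization(INDArray gamma, INDArray beta) {
        this.gamma = gamma;
        this.beta = beta;
    }

    @Override
    public INDArray forward(INDArray x, boolean train_flg) {
        if (runningMean == null) {
            runningMean = Nd4j.zeros(x.size(1));
            runningVar = Nd4j.zeros(x.size(1));
        }
        INDArray xn;
        if (train_flg) {
            // Normalize with the statistics of the current mini-batch.
            INDArray mu = x.mean(0);
            INDArray xc = x.subRowVector(mu);
            INDArray variance = xc.mul(xc).mean(0);
            xn = xc.divRowVector(Transforms.sqrt(variance.add(1e-7)));
            // Keep exponential moving averages for inference.
            runningMean = runningMean.mul(momentum).add(mu.mul(1.0 - momentum));
            runningVar = runningVar.mul(momentum).add(variance.mul(1.0 - momentum));
        } else {
            // Normalize with the running statistics collected during training.
            INDArray xc = x.subRowVector(runningMean);
            xn = xc.divRowVector(Transforms.sqrt(runningVar.add(1e-7)));
        }
        return xn.mulRowVector(gamma).addRowVector(beta);
    }

    @Override
    public INDArray backward(INDArray dout) {
        // Omitted in this sketch; the real layer also produces gradients for gamma and beta.
        throw new UnsupportedOperationException();
    }
}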

INDArray x_train;
INDArray t_train;
int max_epochs = 20;
int train_size;
int batch_size = 100;
double learning_rate = 0.01;
DataSet trainDataSet;

List<List<Double>> __train(String weight_init_std) {
    MultiLayerNetExtend bn_network = new MultiLayerNetExtend(
        784, new int[] {100, 100, 100, 100, 100}, 10,
        /*activation=*/"relu",
        /*weight_init_std=*/ weight_init_std,
        /*weight_decay_lambda=*/ 0,
        /*use_dropout=*/ false,
        /*dropout_ration=*/ 0.5,
        /*use_batchNorm=*/ true);
    MultiLayerNetExtend network = new MultiLayerNetExtend(
        784, new int[] {100, 100, 100, 100, 100}, 10,
        /*activation=*/"relu",
        /*weight_init_std=*/ weight_init_std,
        /*weight_decay_lambda=*/ 0,
        /*use_dropout=*/ false,
        /*dropout_ration=*/ 0.5,
        /*use_batchNorm=*/ false);
    List<MultiLayerNetExtend> networks = Arrays.asList(bn_network, network);
    Optimizer optimizer = new SGD(learning_rate);
    List<Double> train_acc_list = new ArrayList<>();
    List<Double> bn_train_acc_list = new ArrayList<>();
    int iter_per_epoch = Math.max(train_size / batch_size, 1);
    int epoch_cnt = 0;
    // Effectively an infinite loop; we break once max_epochs evaluations have been recorded.
    for (int i = 0; i < 1000000000; ++i) {
        DataSet sample = trainDataSet.sample(batch_size);
        INDArray x_batch = sample.getFeatureMatrix();
        INDArray t_batch = sample.getLabels();
        // Update both networks (with and without Batch Normalization) on the same mini-batch.
        for (MultiLayerNetExtend _network : networks) {
            Params grads = _network.gradient(x_batch, t_batch);
            optimizer.update(_network.params, grads);
        }
        // Record the training accuracy of both networks once per epoch.
        if (i % iter_per_epoch == 0) {
            double train_acc = network.accuracy(x_train, t_train);
            double bn_train_acc = bn_network.accuracy(x_train, t_train);
            train_acc_list.add(train_acc);
            bn_train_acc_list.add(bn_train_acc);
            System.out.println("epoch:" + epoch_cnt + " | " + train_acc + " - " + bn_train_acc);
            ++epoch_cnt;
            if (epoch_cnt >= max_epochs)
                break;
        }
    }
    return Arrays.asList(train_acc_list, bn_train_acc_list);
}

@Test
public void C6_3_2_Batch_Normalization_Evaluation() throws IOException {
    // Java version of ch06/batch_norm_test.py.
    MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
    x_train = train.normalizedImages();
    t_train = train.oneHotLabels();
    trainDataSet = new DataSet(x_train, t_train);
    train_size = x_train.size(0);

    // Drawing the graphs
    File dir = Constants.WeightImages;
    if (!dir.exists()) dir.mkdirs();
    String[] names = {"BatchNormalization", "Normal"};
    Color[] colors = {Color.BLUE, Color.RED};
    // 16 initial-weight standard deviations, log-spaced from 10^0 down to 10^-4.
    INDArray weight_scale_list = Functions.logspace(0, -4, 16);
    INDArray x = Functions.arrange(max_epochs);
    for (int i = 0; i < weight_scale_list.length(); ++i) {
        System.out.println( "============== " + (i+1) + "/16" + " ==============");
        double w = weight_scale_list.getDouble(i);
        List<List<Double>> acc_list = __train(String.valueOf(w));
        GraphImage graph = new GraphImage(640, 480, -1, -0.1, 20, 1.0);
        for (int j = 0; j < names.length; ++j) {
            graph.color(colors[j]);
            graph.textInt(names[j] + " : " + w, 20, 20 * j + 20);
            graph.plot(0, acc_list.get(j).get(0));
            for (int k = 1; k < acc_list.get(j).size(); ++k) {
                graph.line(k - 1, acc_list.get(j).get(k - 1), k, acc_list.get(j).get(k));
                graph.plot(k, acc_list.get(j).get(k));
            }
        }
        File file = new File(dir, "BatchNormalization#" + w + ".png");
        graph.writeTo(file);
    }
}

The resulting graphs look like this:

W=1.0: BatchNormalization#1.0.png

W=0.29286444187164307: BatchNormalization#0.29286444187164307.png

W=0.00009999999747378752: BatchNormalization#9.999999747378752E-5.png

Unlike in the book, the smaller W (the standard deviation of the initial weights) is, the faster learning progresses.
