I tried to implement deep learning in Java

"Deep learning from scratch" is a well-known book on the approach of implementing and understanding deep learning from scratch, but the book "Deep learning mathematics" with the same purpose is very popular. It was easy to understand, so I tried Python code Implemented in Java.

Java is something of a wasteland for machine learning (or is data science simply Python's monopoly?), but if you happen to be the rare sort of person who wants to understand the logic of machine learning in Java, I hope this serves as a reference.

Like the original, I implemented everything from scratch, avoiding libraries as much as possible. (The Commons Math library is used for matrix and vector calculations, for reasons of execution speed.) Also, because the code follows the book's content, some parts are not idiomatic Java and are somewhat redundant.

Introducing the book "Deep Learning Mathematics"

First of all, I would like to briefly introduce the contents of the book.

The book is divided into four parts: the introductory part covers the basics of machine learning, the theory part covers the mathematics, the implementation part implements machine learning algorithms in Python, and the advanced part explains slightly more advanced topics at a mathematical or conceptual level (without implementations).

I think the features of this book are the following three points.

  1. The dependencies between chapters are visualized, making it easy to get a bird's-eye view of the whole book.
  2. The mathematics is not just a list of formulas; each formula is properly derived. (Some parts are not fully rigorous, but since this is not a specialized mathematics book, I think the balance is good.)
  3. The implementation starts with simple linear regression and gradually evolves into deep learning, so you can deepen your understanding step by step just by following the differences at each stage.

I think point 3 in particular is important.

The outline of each part is as follows.

Introduction

The introductory part is just one chapter, which explains the outline and mechanism of machine learning. As an example, the regression problem of predicting weight from height is solved analytically using the algebraic technique of completing the square.

Theory

The theory part mainly explains mathematics, narrowed down to what is needed for deep learning. It spans Chapters 2 to 6: Chapter 2 covers differentiation and integration, Chapter 3 vectors and matrices, Chapter 4 differentiation of multivariable functions (partial differentiation), Chapter 5 exponential and logarithmic functions, and Chapter 6 probability and statistics. The level is roughly from high school mathematics up to the general-education courses of a university science program.

The chapter structure is well thought out, and if you read it from the beginning you can follow it without running into gaps.

I won't explain the mathematics in this article, but if you want a deeper understanding of how deep learning works, you should read the theory part and then work through the implementation part in the second half.

Implementation

In the implementation part, the logic of machine learning is implemented in Python. It spans Chapters 7 to 10: Chapter 7 implements a linear regression model (regression problem), Chapter 8 a logistic regression model (binary classification problem), Chapter 9 a logistic regression model (multi-class classification problem), and Chapter 10 deep learning.

The implementation part also has a carefully designed chapter structure: each chapter adds only one or two concepts to the previous one, so you can understand it step by step. The differences between the chapters are also summarized in a table on the opening spread, so take a look if you get confused.

Advanced

The advanced part consists of only Chapter 11. It explains, at the level of formulas and concepts, CNN models (strong at images), RNN models (strong at time-series data), the idea of numerical differentiation (computing gradients numerically), optimization algorithms more efficient than plain gradient descent, countermeasures against overfitting, and how to initialize the weight matrices.

Java implementation of linear regression model (regression problem)

Now let's start implementing in Java. The linear regression model uses The Boston Housing Dataset, a set of real estate and regional statistics, to predict property prices.

Simple linear regression model

First, we implement a simple linear regression model that uses a single variable, the number of rooms (RM).

The constructor mainly performs the following processing:

  - Set the hyperparameters (number of iterations and learning rate)
  - Create the training data and the target (correct answer) data
  - Initialize the weight vector to 1

In the learn method, the following three steps are repeated to train the model:

  - Calculate the predicted values
  - Calculate the error
  - Update the weights by multiplying the gradient by the learning rate (expressed as a formula below)
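
In formula form, writing the training data matrix as X (with M rows), the weight vector as W, the targets as y_t, and the learning rate as alpha, each iteration performs the gradient-descent step for the loss mean((yp - yt)^2) / 2 that the code prints:

\[
y_p = XW, \qquad W \leftarrow W - \frac{\alpha}{M} X^{\top} (y_p - y_t)
\]

This corresponds to the line W = sub(W, mult(div(dot(t(x), yd), M), alpha)) in the code below.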

LinearSingleRegression.java


package math.deeplearning.ch07;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * Simple linear regression model.
 */
public class LinearSingleRegression {
    //Learning rate
    private double alpha;
    //Number of training iterations
    private int iters;
    //Training data
    private RealMatrix x;
    //Target (correct answer) data
    private RealVector yt;
    //Number of rows of input data
    private int M;
    //Number of columns of input data
    private int D;
    //Weight vector
    private RealVector W;

    /**
     * Initialization process.
     *
     * @param iters number of training iterations
     * @param alpha learning rate
     */
    public LinearSingleRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        //Load The Boston Housing Dataset
        RealMatrix boston = loadBoston();
        //Extract the number of rooms (RM) column as training data and add the dummy variable 1
        x = addBiasCol(extractCol(boston, new int[]{5}));
        //Extract the property price as correct answer data
        yt = boston.getColumnVector(13);
        //Number of rows of training data
        M = x.getRowDimension();
        //Number of columns of training data
        D = x.getColumnDimension();

        //Initialize the weight vector with 1
        W = add(MatrixUtils.createRealVector(new double[D]), 1.0);
    }

    public static void main(String[] args) throws Exception {
        //Set the number of iterations to 50000 and the learning rate to 0.01
        LinearSingleRegression lsr = new LinearSingleRegression(50000, 0.01);
        //learn
        lsr.learn();
    }

    /**
     * Learn.
     */
    public void learn() {
        for (int i = 0; i < iters; i++) {
            //Calculate the predicted value yp
            RealVector yp = dot(x, W);
            //Calculate error yd
            RealVector yd = sub(yp, yt);
            //Update the weight by multiplying the gradient by the learning rate
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            //Display the error every 100 iterations
            if (i % 100 == 0)
                System.out.println(i + " " + mean(pow(yd, 2)) / 2);
        }
    }
}

By packing the tedious processing into the math.deeplearning.common.Util class, the machine learning logic itself could be implemented fairly simply. When you run it, you can see the error shrinking as training progresses.

Output of the simple linear regression model


0 154.2249338409091
100 29.617518011568446
...
49800 21.80032626850963
49900 21.800325071320316

Multiple linear regression model

Next, we implement a multiple linear regression model that uses two variables: the number of rooms (RM) and the low-income rate (LSTAT).

LinearMultipleRegression.java


package math.deeplearning.ch07;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * Multiple linear regression model.
 */
public class LinearMultipleRegression {
    //Learning rate
    private double alpha;
    //Number of training iterations
    private int iters;
    //Training data
    private RealMatrix x;
    //Target (correct answer) data
    private RealVector yt;
    //Number of rows of input data
    private int M;
    //Number of columns of input data
    private int D;
    //Weight vector
    private RealVector W;

    /**
     * Initialization process.
     *
     * @param iters number of training iterations
     * @param alpha learning rate
     */
    public LinearMultipleRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        //Load The Boston Housing Dataset
        RealMatrix boston = loadBoston();
        //Extract the number of rooms (RM) and low-income rate (LSTAT) columns as training data and add the dummy variable 1
        x = addBiasCol(extractCol(boston, new int[]{5, 12}));
        //Extract the property price as correct answer data
        yt = boston.getColumnVector(13);
        //Number of rows of training data
        M = x.getRowDimension();
        //Number of columns of training data
        D = x.getColumnDimension();

        //Initialize the weight vector with 1
        W = add(MatrixUtils.createRealVector(new double[D]), 1.0);
    }

    public static void main(String[] args) throws Exception {
        //Set the number of iterations to 2000 and the learning rate to 0.001
        LinearMultipleRegression lmr = new LinearMultipleRegression(2000, 0.001);
        //learn
        lmr.learn();
    }

    /**
     * Learn.
     */
    public void learn() {
        for (int i = 0; i < iters; i++) {
            //Calculate the predicted value yp
            RealVector yp = dot(x, W);
            //Calculate error yd
            RealVector yd = sub(yp, yt);
            //Update the weight by multiplying the gradient by the learning rate
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            //Display the error every 100 iterations
            if (i % 100 == 0)
                System.out.println(i + " " + mean(pow(yd, 2)) / 2);
        }
    }
}

By adding one more feature, the error becomes smaller than with the simple linear regression model, and training converges faster.

Output of the multiple linear regression model


0 112.06398160770748
100 25.358934200838444
...
1800 15.280256759397282
1900 15.280228371672587

Comparing this with the simple linear regression code, you can see that the only real difference is the part that extracts the training data, shown below; multiple linear regression is handled without changing the linear regression logic at all.

Simple linear regression model


//Extract the number of rooms (RM) column as training data and add the dummy variable 1
x = addBiasCol(extractCol(boston, new int[]{5}));

Multiple linear regression model


//Extract the number of rooms (RM) and low-income rate (LSTAT) columns as training data and add the dummy variable 1
x = addBiasCol(extractCol(boston, new int[]{5, 12}));

Java implementation of logistic regression model (binary classification problem)

Next, we implement a logistic regression model that classifies iris species into two classes using flower-measurement data called the Iris Data Set. In the Iris Data Set, the first 50 rows are Setosa and rows 51 to 100 are Versicolour, so we extract the first 100 rows and shuffle their order.

java


//Read two types of iris data, Setosa and Versicolour, from the Iris Data Set
RealMatrix iris = shuffle(loadIris(0, 100));

The processing flow is the same as for linear regression, but since this time we are solving a binary classification problem with a logistic regression model, we insert the sigmoid function as the activation function to convert the output into a probability value between 0 and 1.
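
The sigmoid helper used below comes from the Util class and is not shown in this article. A minimal sketch, using Commons Math's built-in Sigmoid function (an assumption about how Util might implement it), could be:

import org.apache.commons.math3.analysis.function.Sigmoid;
import org.apache.commons.math3.linear.RealVector;

//Hypothetical sketch of the sigmoid helper assumed from Util:
//applies 1 / (1 + e^(-x)) element-wise to the vector
public static RealVector sigmoid(RealVector v) {
    return v.map(new Sigmoid());
}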

Binary logistic regression model

BinaryLogisticRegression.java


package math.deeplearning.ch08;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * Logistic regression model (binary classification).
 */
public class BinaryLogisticRegression {
    //Learning rate
    private double alpha;
    //Number of training iterations
    private int iters;
    //Training data
    private RealMatrix x;
    //Test data
    private RealMatrix xTest;
    //Target (correct answer) data
    private RealVector yt;
    //Target data for testing
    private RealVector ytTest;
    //Number of rows of input data
    private int M;
    //Number of columns of input data
    private int D;
    //Weight vector
    private RealVector W;

    /**
     * Initialization process.
     *
     * @param iters number of training iterations
     * @param alpha learning rate
     */
    public BinaryLogisticRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        //Read two types of iris data, Setosa and Versicolour, from the Iris Data Set
        RealMatrix iris = shuffle(loadIris(0, 100));
        //Extract the sepal length column and sepal width column as training data, and add dummy variable 1.
        x = addBiasCol(extractRowCol(iris, 0, 69, 0, 1));
        //Extract the sepal length column and sepal width column as test data, and add dummy variable 1.
        xTest = addBiasCol(extractRowCol(iris, 70, 99, 0, 1));
        //Extract the type of iris as correct learning data
        yt = extractRowCol(iris, 0, 69, 4);
        //Extract the type of iris as the correct answer data for the test
        ytTest = extractRowCol(iris, 70, 99, 4);
        //Number of rows of training data
        M = x.getRowDimension();
        //Number of columns of training data
        D = x.getColumnDimension();

        //Initialize the weight vector with 1
        W = add(MatrixUtils.createRealVector(new double[D]), 1.0);
    }

    public static void main(String[] args) throws Exception {
        //Set the number of iterations to 10000 and the learning rate to 0.01
        BinaryLogisticRegression blr = new BinaryLogisticRegression(10000, 0.01);
        //learn
        blr.learn();
    }

    /**
     * Learn.
     */
    public void learn() {
        //learn
        for (int i = 0; i < iters; i++) {
            //Calculate the predicted value yp
            RealVector yp = sigmoid(dot(x, W));
            //Calculate error yd
            RealVector yd = sub(yp, yt);
            //Update the weight by multiplying the gradient by the learning rate
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            //Display the error and accuracy every 10 iterations
            if (i % 10 == 0) {
                RealVector p = sigmoid(dot(xTest, W));
                System.out.print("iter = " + i + "\tloss = " + crossEntropy(ytTest, p));
                System.out.println("\tscore = " + calcAccuracy(ytTest, p));
            }
        }
    }
}

At first the accuracy was around 50%, but in the end it reached 100%.

Output of binary logistic regression model


iter = 0	loss = 4.401398657630698	score = 0.5333333333333333
iter = 10	loss = 3.4820950219350593	score = 0.5333333333333333
...
iter = 9980	loss = 0.10275578614196225	score = 1.0
iter = 9990	loss = 0.10270185332637241	score = 1.0
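
The loss and score printed during training come from the crossEntropy and calcAccuracy helpers in Util, which are also not shown in this article. Minimal sketches for the binary case (assuming the loss is averaged over samples and a prediction counts as correct when rounding it matches the label) could look like:

import org.apache.commons.math3.linear.RealVector;

//Hypothetical sketch of the binary crossEntropy helper assumed from Util:
//mean of -(y*log(p) + (1-y)*log(1-p)) over all samples
public static double crossEntropy(RealVector yt, RealVector yp) {
    double sum = 0.0;
    for (int i = 0; i < yt.getDimension(); i++) {
        double y = yt.getEntry(i);
        double p = yp.getEntry(i);
        sum -= y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p);
    }
    return sum / yt.getDimension();
}

//Hypothetical sketch of the binary calcAccuracy helper assumed from Util:
//fraction of samples where the rounded probability matches the label
public static double calcAccuracy(RealVector yt, RealVector yp) {
    int correct = 0;
    for (int i = 0; i < yt.getDimension(); i++)
        if (Math.round(yp.getEntry(i)) == (long) yt.getEntry(i)) correct++;
    return (double) correct / yt.getDimension();
}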

Comparing this with the linear regression code, you can see that the only substantial difference is that the sigmoid function is applied when calculating the predicted value yp, as shown below. (The data is now split into training and test sets, but that is not an essential difference.)

Linear regression model


//Calculate the predicted value yp
RealVector yp = dot(x, W);

Logistic regression model


//Calculate the predicted value yp
RealVector yp = sigmoid(dot(x, W));

Java implementation of logistic regression model (multi-class classification problem)

Next, using the same Iris Data Set, we implement a logistic regression model that uses all of the iris data to classify the species into three classes.

Since the problem changes from two-class to three-class classification, we change the activation function from the sigmoid function to the softmax function, sketched below.
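
The softmax helper also comes from Util and is not shown in this article. A minimal sketch, assuming it normalizes each row of the matrix into a probability distribution (subtracting the row maximum for numerical stability), could look like:

import java.util.Arrays;
import org.apache.commons.math3.linear.RealMatrix;

//Hypothetical sketch of the softmax helper assumed from Util:
//turns each row into probabilities via exp(x - max) / sum(exp(x - max))
public static RealMatrix softmax(RealMatrix m) {
    RealMatrix result = m.copy();
    for (int i = 0; i < m.getRowDimension(); i++) {
        double[] row = m.getRow(i);
        double max = Arrays.stream(row).max().getAsDouble();
        double sum = 0.0;
        for (int j = 0; j < row.length; j++) {
            row[j] = Math.exp(row[j] - max);
            sum += row[j];
        }
        for (int j = 0; j < row.length; j++) row[j] /= sum;
        result.setRow(i, row);
    }
    return result;
}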

Multi-class logistic regression model

MultipleLogisticRegression.java


package math.deeplearning.ch09;

import org.apache.commons.math3.linear.*;
import java.io.IOException;
import static math.deeplearning.common.Util.*;

/**
 * Logistic regression model (multi-class classification).
 */
public class MultipleLogisticRegression {
    //Learning rate
    private double alpha;
    //Number of training iterations
    private int iters;
    //Training data
    private RealMatrix x;
    //Test data
    private RealMatrix xTest;
    //Target (correct answer) data
    private RealMatrix yt;
    //Target data for testing
    private RealMatrix ytTest;
    //Number of rows of input data
    private int M;
    //Number of columns of input data
    private int D;
    //Weight matrix
    private RealMatrix W;

    /**
     * Initialization process.
     *
     * @param iters number of training iterations
     * @param alpha learning rate
     */
    public MultipleLogisticRegression(int iters, double alpha) throws IOException {
        this.iters = iters;
        this.alpha = alpha;

        //Read all three types of iris data from the Iris Data Set
        RealMatrix iris = shuffle(loadIris());
        //Extract the sepal length column and petal length column as training data, and add dummy variable 1.
        x = addBiasCol(extractRowCol(iris, 0, 74, new int[]{0, 2}));
        //Extract the sepal length column and petal length column as test data, and add dummy variable 1.
        xTest = addBiasCol(extractRowCol(iris, 75, 149, new int[]{0, 2}));
        //When using all 4 variables
        // x = addBiasCol(extractRowCol(iris, 0, 74, 0, 3));
        // xTest = addBiasCol(extractRowCol(iris, 75, 149, 0, 3));
        //Extract the type of iris as the correct answer data for learning and convert it to OneHotVector format
        yt = oneHotEncode(extractRowCol(iris, 0, 74, 4), 3);
        //Extract the type of iris as the correct answer data of the test and convert it to OneHotVector format
        ytTest = oneHotEncode(extractRowCol(iris, 75, 149, 4), 3);
        //Number of rows of training data
        M = x.getRowDimension();
        //Number of columns of training data
        D = x.getColumnDimension();

        //Initialize the weight matrix with 1
        W = add(MatrixUtils.createRealMatrix(D, 3), 1.0);
    }

    public static void main(String[] args) throws Exception {
        //Set the number of iterations to 10000 and the learning rate to 0.01
        MultipleLogisticRegression mlr = new MultipleLogisticRegression(10000, 0.01);
        mlr.learn();
    }

    /**
     * Learn.
     */
    public void learn() {
        //learn
        for (int i = 0; i < iters; i++) {
            //Calculate the predicted value yp
            RealMatrix yp = softmax(dot(x, W));
            //Calculate error yd
            RealMatrix yd = sub(yp, yt);
            //Update the weight by multiplying the gradient by the learning rate
            W = sub(W, mult(div(dot(t(x), yd), M), alpha));

            //Display the error and accuracy every 10 iterations
            if (i % 10 == 0) {
                RealMatrix p = softmax(dot(xTest, W));
                System.out.print("iter = " + i + "\tloss = " + crossEntropy(ytTest, p));
                System.out.println("\tscore = " + calcAccuracy(ytTest, p));
            }
        }
    }
}

The initial accuracy was around 1/3, but in the end it reached 97%.

Output of the multi-class logistic regression model


iter = 0	loss = 1.089863468306522	score = 0.30666666666666664
iter = 10	loss = 1.0505735104517255	score = 0.30666666666666664
...
iter = 9980	loss = 0.18412409250145656	score = 0.9733333333333334
iter = 9990	loss = 0.18403868595917505	score = 0.9733333333333334

Compared with the code of the binary classification logistic regression model, the main differences are the following three points:

  - The weights change from a vector to a matrix (because each class has its own weights)
  - The target (correct answer) data changes from a 0/1 binary format to one-hot vector format, as sketched below (because the problem changes from binary to multi-class classification)
  - The activation function used to calculate the predicted value yp changes from the sigmoid function to the softmax function
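
The one-hot conversion is done by the oneHotEncode helper in Util, which is not shown in this article. A minimal sketch, assuming the labels are integer class indexes starting at 0, could be:

import org.apache.commons.math3.linear.*;

//Hypothetical sketch of the oneHotEncode helper assumed from Util:
//converts a vector of class labels (0 .. numClasses-1) into a one-hot matrix
public static RealMatrix oneHotEncode(RealVector labels, int numClasses) {
    RealMatrix encoded = MatrixUtils.createRealMatrix(labels.getDimension(), numClasses);
    for (int i = 0; i < labels.getDimension(); i++) {
        encoded.setEntry(i, (int) labels.getEntry(i), 1.0);
    }
    return encoded;
}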

Java implementation of deep learning model

Now it's time to implement the deep learning model.

Here, we solve the problem of classifying MNIST, a dataset of handwritten digit images, into 10 classes from 0 to 9.

3-layer deep learning model

First, we implement a 3-layer deep learning model with just one hidden layer.

DeepLearning.java


package math.deeplearning.ch10;

import org.apache.commons.math3.linear.RealMatrix;
import java.io.IOException;
import java.util.*;
import static math.deeplearning.common.Util.*;

/**
 * Deep learning (1 hidden layer).
 */
public class DeepLearning {
    //Number of rows of training data
    private int M;
    //Number of columns of training data (number of pixels in the image)
    private int D;
    //Number of classification classes
    private int N;
    //Number of training iterations
    private int iters;
    //Batch size
    private int batchSize;
    //Learning rate
    private double alpha;

    //MNIST image data and labels
    private RealMatrix xAll;
    private RealMatrix xTest;
    private RealMatrix ytAll;
    private RealMatrix ytTest;

    //Weight matrices
    private RealMatrix V;
    private RealMatrix W;

    public DeepLearning(int iters, int H, int batchSize, double alpha) throws IOException {
        //Read the MNIST dataset
        xAll = addBiasCol(div(loadMnistImage(MNIST_TRAIN_IMAGE_FILE_NAME), 255));
        xTest = addBiasCol(div(loadMnistImage(MNIST_TEST_IMAGE_FILE_NAME), 255));
        ytAll = oneHotEncode(loadMnistLabel(MNIST_TRAIN_LABEL_FILE_NAME), 10);
        ytTest = oneHotEncode(loadMnistLabel(MNIST_TEST_LABEL_FILE_NAME), 10);

        M = xAll.getRowDimension();
        D = xAll.getColumnDimension();
        N = ytAll.getColumnDimension();

        this.iters = iters;
        this.batchSize = batchSize;
        this.alpha = alpha;

        //Initialize the weight matrix with He Normal
        V = initW(D, H);
        W = initW(H + 1, N);
    }

    public static void main(String... args) throws Exception {
        //Set the number of iterations to 10000, the number of hidden-layer neurons to 128, the batch size to 512, and the learning rate to 0.01
        DeepLearning dl = new DeepLearning(10000, 128, 512, 0.01);
        //learn
        dl.learn();
    }

    public void learn() {
        //Initialize random sampling index
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < M; i++) indexes.add(i);

        for (int i = 0; i < iters; i++) {
            //Sample a mini-batch of training data
            List<Integer> index = randIndex(indexes, M, batchSize);
            RealMatrix x = sampling(xAll, index);
            RealMatrix yt = sampling(ytAll, index);

            //Calculate the output value of each layer
            RealMatrix a = dot(x, V);
            RealMatrix b = reLU(a);
            RealMatrix b1 = addBiasCol(b);
            RealMatrix u = dot(b1, W);
            RealMatrix yp = softmax(u);
            //Calculate the error of each layer
            RealMatrix yd = sub(yp, yt);
            RealMatrix bd = mult(step(a), dot(yd, t(removeBias(W))));
            //Update the weight of each layer by multiplying the gradient by the learning rate
            W = sub(W, mult(div(dot(t(b1), yd), batchSize), alpha));
            V = sub(V, mult(div(dot(t(x), bd), batchSize), alpha));

            //Display the error and accuracy every 100 iterations
            if (i % 100 == 0) {
                RealMatrix p = softmax(dot(addBiasCol(reLU(dot(xTest, V))), W));
                System.out.print(i + " " + crossEntropy(ytTest, p) + " ");
                System.out.println(calcAccuracy(ytTest, p));
            }
        }
    }
}

Initially the error is around 2.3 and the accuracy around 10%, but by the end the error reaches around 0.21 and the accuracy around 94%.

Output of 3-layer deep learning model


0 2.449633365625842 0.0951
100 1.5349024136564533 0.6818
...
9800 0.21109711296030495 0.9416
9900 0.21035221505955806 0.9419

The number of weight matrices increases by one: V is the weight matrix for the input layer and W is the weight matrix for the hidden layer. Also, the initial values of the weight matrices are no longer fixed at 1; they are initialized with a method called He normal initialization.

//Initialize the weight matrix with He Normal
V = initW(D, H);
W = initW(H + 1, N);
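
He normal initialization draws each weight from a Gaussian with mean 0 and standard deviation sqrt(2 / fan_in), which works well with ReLU activations. The initW helper is not shown in this article; a minimal sketch, assuming its first argument is the fan-in (number of input units), could be:

import java.util.Random;
import org.apache.commons.math3.linear.*;

//Hypothetical sketch of the initW helper assumed from Util:
//He normal initialization, drawing each weight from N(0, sqrt(2 / fanIn))
public static RealMatrix initW(int fanIn, int fanOut) {
    Random rand = new Random();
    double stddev = Math.sqrt(2.0 / fanIn);
    RealMatrix w = MatrixUtils.createRealMatrix(fanIn, fanOut);
    for (int i = 0; i < fanIn; i++)
        for (int j = 0; j < fanOut; j++)
            w.setEntry(i, j, rand.nextGaussian() * stddev);
    return w;
}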

ReLU is used as the activation function of the hidden layer.

//Calculate the output value of each layer
RealMatrix a = dot(x, V);
RealMatrix b = reLU(a);
RealMatrix b1 = addBiasCol(b);
RealMatrix u = dot(b1, W);
RealMatrix yp = softmax(u);

The error of each layer is calculated by error backpropagation; the derivative of ReLU is computed with the step function (both helpers are sketched after the following snippet).

//Calculate the error of each layer
RealMatrix yd = sub(yp, yt);
RealMatrix bd = mult(step(a), dot(yd, t(removeBias(W))));
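
The reLU and step helpers also come from Util and are not shown here. Minimal sketches, assuming both operate element-wise on a matrix, could look like:

import org.apache.commons.math3.linear.RealMatrix;

//Hypothetical sketch of the reLU helper assumed from Util:
//element-wise max(0, x)
public static RealMatrix reLU(RealMatrix m) {
    RealMatrix r = m.copy();
    for (int i = 0; i < r.getRowDimension(); i++)
        for (int j = 0; j < r.getColumnDimension(); j++)
            r.setEntry(i, j, Math.max(0.0, r.getEntry(i, j)));
    return r;
}

//Hypothetical sketch of the step helper assumed from Util:
//1 where x > 0, else 0 (the derivative of ReLU)
public static RealMatrix step(RealMatrix m) {
    RealMatrix r = m.copy();
    for (int i = 0; i < r.getRowDimension(); i++)
        for (int j = 0; j < r.getColumnDimension(); j++)
            r.setEntry(i, j, r.getEntry(i, j) > 0.0 ? 1.0 : 0.0);
    return r;
}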

The weights of each layer are updated by multiplying the gradients by the learning rate.

//Update the weight of each layer by multiplying the gradient by the learning rate
W = sub(W, mult(div(dot(t(b1), yd), batchSize), alpha));
V = sub(V, mult(div(dot(t(x), bd), batchSize), alpha));

4-layer deep learning model

Finally, we implement a 4-layer deep learning model with one additional hidden layer.

DeepLearning2.java


package math.deeplearning.ch10;

import org.apache.commons.math3.linear.RealMatrix;
import java.io.IOException;
import java.util.*;
import static math.deeplearning.common.Util.*;

/**
 * Deep learning (2 hidden layers).
 */
public class DeepLearning2 {
    //Number of rows of training data
    private int M;
    //Number of columns of training data (number of pixels in the image)
    private int D;
    //Number of classification classes
    private int N;
    //Number of training iterations
    private int iters;
    //Batch size
    private int batchSize;
    //Learning rate
    private double alpha;

    //MNIST image data and labels
    private RealMatrix xAll;
    private RealMatrix xTest;
    private RealMatrix ytAll;
    private RealMatrix ytTest;

    //Weight matrices
    private RealMatrix U;
    private RealMatrix V;
    private RealMatrix W;

    public DeepLearning2(int iters, int H, int batchSize, double alpha) throws IOException {
        //Read the MNIST dataset
        xAll = addBiasCol(div(loadMnistImage(MNIST_TRAIN_IMAGE_FILE_NAME), 255));
        xTest = addBiasCol(div(loadMnistImage(MNIST_TEST_IMAGE_FILE_NAME), 255));
        ytAll = oneHotEncode(loadMnistLabel(MNIST_TRAIN_LABEL_FILE_NAME), 10);
        ytTest = oneHotEncode(loadMnistLabel(MNIST_TEST_LABEL_FILE_NAME), 10);

        M = xAll.getRowDimension();
        D = xAll.getColumnDimension();
        N = ytAll.getColumnDimension();

        this.iters = iters;
        this.batchSize = batchSize;
        this.alpha = alpha;

        //Initialize the weight matrix with He Normal
        U = initW(D, H);
        V = initW(H + 1, H);
        W = initW(H + 1, N);
    }

    public static void main(String... args) throws Exception {
        //Set the number of iterations to 10000, the number of hidden-layer neurons to 128, the batch size to 512, and the learning rate to 0.01
        DeepLearning2 dl2 = new DeepLearning2(10000, 128, 512, 0.01);
        //learn
        dl2.learn();
    }

    public void learn() {
        //Initialize random sampling index
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < M; i++) indexes.add(i);

        for (int i = 0; i < iters; i++) {
            //Sample a mini-batch of training data
            List<Integer> index = randIndex(indexes, M, batchSize);
            RealMatrix x = sampling(xAll, index);
            RealMatrix yt = sampling(ytAll, index);

            //Calculate the output value of each layer
            RealMatrix a = dot(x, U);
            RealMatrix b = reLU(a);
            RealMatrix b1 = addBiasCol(b);
            RealMatrix c = dot(b1, V);
            RealMatrix d = reLU(c);
            RealMatrix d1 = addBiasCol(d);
            RealMatrix u = dot(d1, W);
            RealMatrix yp = softmax(u);
            //Calculate the error of each layer
            RealMatrix yd = sub(yp, yt);
            RealMatrix dd = mult(step(c), dot(yd, t(removeBias(W))));
            RealMatrix bd = mult(step(a), dot(dd, t(removeBias(V))));
            //Update the weight of each layer by multiplying the gradient by the learning rate
            W = sub(W, mult(div(dot(t(d1), yd), batchSize), alpha));
            V = sub(V, mult(div(dot(t(b1), dd), batchSize), alpha));
            U = sub(U, mult(div(dot(t(x), bd), batchSize), alpha));

            //Display the error and accuracy every 100 iterations
            if (i % 100 == 0) {
                RealMatrix p = softmax(dot(addBiasCol(reLU(dot(addBiasCol(reLU(dot(xTest, U))), V))), W));
                System.out.print(i + " " + crossEntropy(ytTest, p) + " ");
                System.out.println(calcAccuracy(ytTest, p));
            }
        }
    }
}

The 3-layer deep learning model reached an error around 0.21 and an accuracy around 94%; in the 4-layer model, with two hidden layers and thus greater expressive power, the error improves to around 0.15 and the accuracy to around 95%.

Output of 4-layer deep learning model


0 2.418195100372308 0.1035
100 1.4860509069098333 0.6518
...
9800 0.15087335052084305 0.9552
9900 0.14996068028907877 0.9556

In the weight-matrix initialization, one more weight matrix is added, corresponding to the added hidden layer.

3-layer deep learning model


//Initialize the weight matrix with He Normal
V = initW(D, H);
W = initW(H + 1, N);

4-layer deep learning model


//Initialize the weight matrix with He Normal
U = initW(D, H);
V = initW(H + 1, H);
W = initW(H + 1, N);

In the calculation of the output values of each layer, the processing for the additional hidden layer is added.

3-layer deep learning model


//Calculate the output value of each layer
RealMatrix a = dot(x, V);
RealMatrix b = reLU(a);
RealMatrix b1 = addBiasCol(b);
RealMatrix u = dot(b1, W);
RealMatrix yp = softmax(u);

4-layer deep learning model


//Calculate the output value of each layer
RealMatrix a = dot(x, U);
RealMatrix b = reLU(a);
RealMatrix b1 = addBiasCol(b);
RealMatrix c = dot(b1, V);
RealMatrix d = reLU(c);
RealMatrix d1 = addBiasCol(d);
RealMatrix u = dot(d1, W);
RealMatrix yp = softmax(u);

The error backpropagation likewise gains the processing for the added hidden layer.

3-layer deep learning model


//Calculate the error of each layer
RealMatrix yd = sub(yp, yt);
RealMatrix bd = mult(step(a), dot(yd, t(removeBias(W))));

4-layer deep learning model


//Calculate the error of each layer
RealMatrix yd = sub(yp, yt);
RealMatrix dd = mult(step(c), dot(yd, t(removeBias(W))));
RealMatrix bd = mult(step(a), dot(dd, t(removeBias(V))));

The weight update processing likewise gains an update for the weight matrix of the added hidden layer.

3-layer deep learning model


//Update the weight of each layer by multiplying the gradient by the learning rate
W = sub(W, mult(div(dot(t(b1), yd), batchSize), alpha));
V = sub(V, mult(div(dot(t(x), bd), batchSize), alpha));

4-layer deep learning model


//Update the weight of each layer by multiplying the gradient by the learning rate
W = sub(W, mult(div(dot(t(d1), yd), batchSize), alpha));
V = sub(V, mult(div(dot(t(b1), dd), batchSize), alpha));
U = sub(U, mult(div(dot(t(x), bd), batchSize), alpha));

Comparing this with the code of the 3-layer model with one hidden layer, you can see that only one hidden layer's worth of weight matrix and calculations has been added.

Summary

I started by implementing a linear regression model in Java to solve a regression problem, and ended up implementing a simple deep learning model that identifies handwritten digits fairly accurately. I'm neither the author nor the publisher, but if you want to understand how deep learning works, please read the book as well. And if it makes things click, try reimplementing it in your favorite language; I think your understanding will deepen severalfold compared to reading alone.

Incidentally, I have also implemented the Python code from "Deep Learning from Scratch" in Java. I couldn't implement it as cleanly in Java, and it stalled at Chapter 6... When my motivation recovers, I'd like to implement Chapter 7 and write it up in an article.

Thank you for reading until the end!
