First steps for deep learning in Java

This article is meant to give a push to people who want to do deep learning in Java but can't take the first step. You can easily experience the power of deep learning, so please read on.

Target audience

This article assumes that the reader is one of the following:

- A Java programmer who wants to do deep learning but has never written a program in Python
- Someone who, for whatever reason, wants to do deep learning in Java
- Someone who wants to use Deeplearning4j

After reading this article, you will be able to set up a deep learning development environment with Deeplearning4j in IntelliJ and identify handwritten digits. You will also be able to run many other deep learning samples besides digit recognition. However, this article does not teach the theory of deep learning. To learn the theory, we recommend reading "[Deep Learning from Scratch: The theory and implementation of deep learning learned with Python](https://www.amazon.co.jp/%E3%82%BC%E3%83%AD%E3%81%8B%E3%82%89%E4%BD%9C%E3%82%8BDeep-Learning-%E2%80%95Python%E3%81%A7%E5%AD%A6%E3%81%B6%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0%E3%81%AE%E7%90%86%E8%AB%96%E3%81%A8%E5%AE%9F%E8%A3%85-%E6%96%8E%E8%97%A4-%E5%BA%B7%E6%AF%85/dp/4873117585)".

Experience first

I think the first step in learning deep learning is to actually run a program and "experience" it, which increases your motivation to learn. Take a look at the following.

[Image: hadwrt.gif]

When you draw a digit by hand, the program shows the digit it identifies (predicts) after "Prediction:" at the bottom of the screen.

Suppose you wanted to write a program that identifies an image of the handwritten digit "7" without using any special library. What algorithm would you use? Would it be enough to classify an image as a "7" if it has a horizontal line at the top and a line running downward from its end point? Could such an algorithm correctly identify every one of the following images as a "7"? And how would you implement it in your program?

[Images: 0.png, 36.png, 255.png, 2671.png, 1096.png, 2622.png, 8387.png]

Digits that are easy for humans to identify are not so easy for computers, at least when we try to identify them with a hand-crafted algorithm. Deep learning, however, can do this. In this article, we will actually run the program shown above.
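To make the question concrete, the rule proposed above might be sketched as follows. This is a deliberately naive illustration and not part of Deeplearning4j or the sample project: the class and method names are hypothetical, and the thresholds (a top bar of at least 8 pixels, a stroke at least 12 rows long) are arbitrary assumptions on a 28x28 black-and-white image.

```java
// A deliberately naive, rule-based "7" detector, i.e. the kind of hand-crafted
// algorithm described above. Names and thresholds are hypothetical assumptions.
public class NaiveSevenDetector {

    static final int MIN_BAR_LENGTH = 8;   // assumed minimum length of the top bar
    static final int MIN_STROKE_ROWS = 12; // assumed minimum length of the downward stroke

    /** Returns true if the image has a horizontal bar at the top and a stroke below it. */
    public static boolean isSeven(boolean[][] img) {
        int rows = img.length, cols = img[0].length;

        // 1. Find the topmost row that contains any ink.
        int top = -1;
        for (int r = 0; r < rows && top < 0; r++) {
            for (int c = 0; c < cols; c++) {
                if (img[r][c]) { top = r; break; }
            }
        }
        if (top < 0) return false; // blank image

        // 2. Find the longest horizontal run of ink in that row ([bestStart, bestEnd)).
        int bestStart = -1, bestEnd = -1, runStart = -1;
        for (int c = 0; c <= cols; c++) {
            boolean ink = c < cols && img[top][c];
            if (ink && runStart < 0) runStart = c;
            if (!ink && runStart >= 0) {
                if (c - runStart > bestEnd - bestStart) { bestStart = runStart; bestEnd = c; }
                runStart = -1;
            }
        }
        if (bestEnd - bestStart < MIN_BAR_LENGTH) return false; // no clear top bar

        // 3. Each of the next MIN_STROKE_ROWS rows must contain ink somewhere under the bar.
        for (int r = top + 1; r <= top + MIN_STROKE_ROWS; r++) {
            if (r >= rows) return false;
            boolean found = false;
            for (int c = bestStart; c < bestEnd; c++) {
                if (img[r][c]) { found = true; break; }
            }
            if (!found) return false;
        }
        return true;
    }

    /** Draws a synthetic "7": a bar from column barStart to 20 on row 4, then a stroke slanting down-left. */
    public static boolean[][] drawSeven(int barStart) {
        boolean[][] img = new boolean[28][28];
        for (int c = barStart; c <= 20; c++) img[4][c] = true;          // top bar
        for (int r = 5; r <= 22; r++) img[r][20 - (r - 5) / 2] = true; // downward stroke
        return img;
    }

    public static void main(String[] args) {
        System.out.println(isSeven(drawSeven(6)));  // long bar: true
        System.out.println(isSeven(drawSeven(15))); // short bar (6 pixels): false
    }
}
```

It accepts a neatly drawn 7 but rejects the same 7 drawn with a shorter top bar, which a human would still read as a 7. Every writing style needs yet another rule, which is exactly the difficulty described above.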

System requirements

The system requirements for Deeplearning4j are as follows:

- Java 1.7 or higher, 64-bit version (also set JAVA_HOME)
- Maven or Gradle
- IntelliJ or Eclipse

Environment used for verification

The environment used to write this article is as follows.

$ mvn -version
Apache Maven 3.5.0
Maven home: /usr/share/maven
Java version: 1.8.0_171, vendor: Oracle Corporation
Java home: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "4.13.0-21-generic", arch: "amd64", family: "unix"

$ git --version
git version 2.14.1

Notes

The build takes up a lot of disk space, so you may want to build only the dl4j-examples subproject of the dl4j-examples repository. To build the entire project, you need at least 15 GB of free space.

Warning: Many JAR files are downloaded to the following directory, so if you delete the directory you cloned the Deeplearning4j examples into, delete this directory as well.

.m2/repository/org/deeplearning4j/

Setting up the development environment

To build the development environment, just do the following:

$ git clone https://github.com/deeplearning4j/dl4j-examples.git
$ cd dl4j-examples/dl4j-examples
$ mvn clean install

Some of the samples use JavaFX to draw their screens, so install OpenJFX (JavaFX) if necessary. Without JavaFX, you will get an error like the following at build time:

[ERROR] /home/tamura/git/dl4j-examples/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/character/harmonies/Piano.java:[4,24] package javafx.animation does not exist

If you have Ubuntu 17.10, you can install it with:

$ sudo apt-get install openjfx

Checking that it works

With the development environment built, let's try handwritten digit recognition (the sample shown earlier), which is the "Hello world!" of deep learning. In deep learning, a model first "learns" from a large amount of data to derive optimal parameters, and then uses those parameters to make "predictions". The procedure is as follows.

  1. When the build is complete, open the project in IntelliJ.

  2. Open the source code of `org.deeplearning4j.examples.convolution.mnist.MnistClassifier` and run it by clicking the green triangle run button on the left side of the editor. This step performs the "learning".
    Messages like the following are output:

    /usr/lib/jvm/java-8-openjdk-amd64/bin/java -javaagent:/home/tamura/idea-IC-181.5087.20 ... (omitted) ... org.deeplearning4j.examples.convolution.mnist.MnistClassifier
    o.d.e.c.m.MnistClassifier - Data load and vectorization...
    o.d.i.r.BaseImageRecordReader - ImageRecordReader: 10 label classes inferred using label generator ParentPathLabelGenerator
    o.d.i.r.BaseImageRecordReader - ImageRecordReader: 10 label classes inferred using label generator ParentPathLabelGenerator
    o.d.e.c.m.MnistClassifier - Network configuration and training...
    o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
    o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 1
    o.n.n.Nd4jBlas - Number of threads used for BLAS: 1
    o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Linux]
    o.n.l.a.o.e.DefaultOpExecutioner - Cores: [4]; Memory: [0.9GB];
    o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [MKL]
    o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
    o.d.o.l.ScoreIterationListener - Score at iteration 0 is 2.4694731759178388
    o.d.o.l.ScoreIterationListener - Score at iteration 10 is 1.078069156582683
    o.d.o.l.ScoreIterationListener - Score at iteration 20 is 0.7327581484283221
    ... (omitted) ...
    o.d.o.l.ScoreIterationListener - Score at iteration 1100 is 0.20279510458591593
    o.d.o.l.ScoreIterationListener - Score at iteration 1110 is 0.10997898485405874
    o.d.e.c.m.MnistClassifier - Completed epoch 0
    o.d.e.c.m.MnistClassifier -

    ========================Evaluation Metrics========================
     # of classes:    10
     Accuracy:        0.9891
     Precision:       0.9891
     Recall:          0.9890
     F1 Score:        0.9891
    Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


    =========================Confusion Matrix=========================
        0    1    2    3    4    5    6    7    8    9
    ---------------------------------------------------
      973    0    0    0    0    0    2    2    3    0 | 0 = 0
        0 1132    0    1    0    1    1    0    0    0 | 1 = 1
        2    3 1018    1    0    0    1    6    1    0 | 2 = 2
        0    0    1 1000    0    3    0    4    1    1 | 3 = 3
        0    0    1    0  973    0    3    0    0    5 | 4 = 4
        1    0    0    5    0  882    2    1    1    0 | 5 = 5
        5    2    0    0    2    3  944    0    2    0 | 6 = 6
        0    2    4    0    0    0    0 1017    2    3 | 7 = 7
        3    0    2    1    0    0    1    2  961    4 | 8 = 8
        4    2    1    1    3    0    0    6    1  991 | 9 = 9

    Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
    ==================================================================

    Process finished with exit code 0
  3. Open the source code of `org.deeplearning4j.examples.convolution.mnist.MnistClassifierUI` and run it.
  4. When the JavaFX window that accepts handwritten digits appears, draw digits in it (this step performs the "prediction").
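The Accuracy, Precision, and Recall lines in the evaluation output can be reproduced from the confusion matrix itself. The following plain-Java sketch is independent of Deeplearning4j (it is not DL4J's `Evaluation` class); it copies the matrix from the run above and computes accuracy plus the macro-averaged precision and recall that the output describes. The class name is hypothetical.

```java
// Recomputes the headline metrics from the confusion matrix printed above.
// Rows are the actual class and columns the predicted class, as the output states.
public class ConfusionMatrixMetrics {

    public static final int[][] MATRIX = {
        {973,    0,    0,    0,   0,   0,   2,    2,   3,   0},
        {  0, 1132,    0,    1,   0,   1,   1,    0,   0,   0},
        {  2,    3, 1018,    1,   0,   0,   1,    6,   1,   0},
        {  0,    0,    1, 1000,   0,   3,   0,    4,   1,   1},
        {  0,    0,    1,    0, 973,   0,   3,    0,   0,   5},
        {  1,    0,    0,    5,   0, 882,   2,    1,   1,   0},
        {  5,    2,    0,    0,   2,   3, 944,    0,   2,   0},
        {  0,    2,    4,    0,   0,   0,   0, 1017,   2,   3},
        {  3,    0,    2,    1,   0,   0,   1,    2, 961,   4},
        {  4,    2,    1,    1,   3,   0,   0,    6,   1, 991},
    };

    /** Accuracy: correct predictions (the diagonal) divided by all predictions. */
    public static double accuracy(int[][] m) {
        long correct = 0, total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m.length; j++) {
                total += m[i][j];
                if (i == j) correct += m[i][j];
            }
        }
        return (double) correct / total;
    }

    /** Macro-averaged precision: mean over classes of diagonal / column sum (assumes every class is predicted at least once). */
    public static double macroPrecision(int[][] m) {
        double sum = 0;
        for (int k = 0; k < m.length; k++) {
            long predicted = 0;
            for (int i = 0; i < m.length; i++) predicted += m[i][k];
            sum += (double) m[k][k] / predicted;
        }
        return sum / m.length;
    }

    /** Macro-averaged recall: mean over classes of diagonal / row sum. */
    public static double macroRecall(int[][] m) {
        double sum = 0;
        for (int k = 0; k < m.length; k++) {
            long actual = 0;
            for (int j = 0; j < m.length; j++) actual += m[k][j];
            sum += (double) m[k][k] / actual;
        }
        return sum / m.length;
    }

    public static void main(String[] args) {
        System.out.printf("Accuracy:  %.4f%n", accuracy(MATRIX));       // 0.9891
        System.out.printf("Precision: %.4f%n", macroPrecision(MATRIX)); // 0.9891
        System.out.printf("Recall:    %.4f%n", macroRecall(MATRIX));    // 0.9890
    }
}
```

Precision for a digit divides its diagonal entry by its column sum (how often that digit was predicted), recall divides it by its row sum (how often the digit actually occurred), and macro-averaging takes the unweighted mean over the 10 classes.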

Reading the source code

So how does this work? Let's go through the [MnistClassifier source code](https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/mnist/MnistClassifier.java), which "learns" the handwritten digit images, from the top.

Note: The following sections require a basic knowledge of deep learning.

First come the fields: the logger, a constant for the temporary directory that the data is extracted to, and a constant for the URL that the handwritten digit images are downloaded from.

MnistClassifier


public class MnistClassifier {

  private static final Logger log = LoggerFactory.getLogger(MnistClassifier.class);
  private static final String basePath = System.getProperty("java.io.tmpdir") + "/mnist";
  private static final String dataUrl = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";

Next is the main() method of this class; calling it starts the training. Each input image is fed to the network as 3D data: 1 channel, 28 pixels high, and 28 pixels wide. The network classifies the ten digits 0 through 9, so the output layer has 10 outputs. The batch size is 54, and the number of epochs is 1.

  public static void main(String[] args) throws Exception {
    int height = 28;
    int width = 28;
    int channels = 1; // single channel for grayscale images
    int outputNum = 10; // 10 digits classification
    int batchSize = 54;
    int nEpochs = 1;
    int iterations = 1;

    int seed = 1234;
    Random randNumGen = new Random(seed);

Next, the program downloads mnist_png.tar.gz, a compressed archive of 70,000 handwritten digit images, from GitHub and extracts it (extraction is skipped if the mnist_png directory already exists).

    log.info("Data load and vectorization...");
    String localFilePath = basePath + "/mnist_png.tar.gz";
    if (DataUtilities.downloadFile(dataUrl, localFilePath))
      log.debug("Data downloaded from {}", dataUrl);
    if (!new File(basePath + "/mnist_png").exists())
      DataUtilities.extractTarGz(localFilePath, basePath);

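The data is divided into training data (60,000 images) and test data (10,000 images), which are held in the iterators trainIter and testIter, respectively. Each image's label is taken from the name of its parent directory, pixel values are scaled from 0-255 to 0-1, and the same scaler is applied to the test data.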

    // vectorization of train data
    File trainData = new File(basePath + "/mnist_png/training");
    FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // parent path as the image label
    ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
    trainRR.initialize(trainSplit);
    DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);

    // pixel values from 0-255 to 0-1 (min-max scaling)
    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fit(trainIter);
    trainIter.setPreProcessor(scaler);

    // vectorization of test data
    File testData = new File(basePath + "/mnist_png/testing");
    FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
    testRR.initialize(testSplit);
    DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
    testIter.setPreProcessor(scaler); // same normalization for better results
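Conceptually, the pipeline above does two simple things to each image: it takes the label from the name of the parent directory (ParentPathLabelGenerator), and it scales pixel values from 0-255 to 0-1 (ImagePreProcessingScaler(0, 1), assuming 8-bit grayscale pixels as the code comment states). The plain-Java sketch below shows both ideas; the helper names and the file path are hypothetical, and this is not how DataVec or ND4J implement them.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

// Plain-Java illustration of two ideas from the data pipeline above:
// (1) labelling an image by the name of its parent directory, and
// (2) min-max scaling pixel values from 0-255 to 0-1.
// This is NOT the DataVec/ND4J implementation; helper names are hypothetical.
public class PipelineSketch {

    /** Hypothetical helper: ".../training/7/123.png" -> "7". */
    public static String labelFromParentPath(Path imageFile) {
        return imageFile.getParent().getFileName().toString();
    }

    /** Hypothetical helper: maps raw 8-bit pixel values 0..255 linearly onto 0..1. */
    public static double[] scaleToUnitRange(int[] pixels) {
        double[] scaled = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            scaled[i] = pixels[i] / 255.0;
        }
        return scaled;
    }

    public static void main(String[] args) {
        Path p = Paths.get("mnist_png", "training", "7", "123.png"); // hypothetical file
        System.out.println(labelFromParentPath(p)); // 7
        double[] s = scaleToUnitRange(new int[] {0, 51, 255});
        System.out.println(s[0] + " " + s[1] + " " + s[2]); // 0.0 0.2 1.0
    }
}
```

Because the MNIST archive stores each digit's images in a directory named after the digit, the directory name is all that is needed as a label.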

Next, the learning rate schedule is stored in a HashMap named lrSchedule. A high learning rate makes training progress quickly at first but makes it hard to converge later, so the learning rate is lowered as the number of processed iterations grows. In this program, one epoch consists of about 1,111 iterations (60,000 training images / batch size 54). The map's keys are iteration numbers and its values are the learning rates to switch to, so the learning rate is lowered step by step.

    log.info("Network configuration and training...");
    Map<Integer, Double> lrSchedule = new HashMap<>();
    lrSchedule.put(0, 0.06); // iteration #, learning rate
    lrSchedule.put(200, 0.05);
    lrSchedule.put(600, 0.028);
    lrSchedule.put(800, 0.0060);
    lrSchedule.put(1000, 0.001);
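The schedule can be pictured as a step function over iteration numbers. The sketch below assumes that the rate in effect at iteration i is the entry with the largest key not greater than i; it uses a TreeMap and is not ND4J's MapSchedule implementation. It also computes the number of minibatches per epoch, assuming the iterator emits the 6 leftover images as a final, smaller batch. The class name is hypothetical.

```java
import java.util.TreeMap;

// Sketch of a step learning-rate schedule keyed by iteration number.
// Assumption: the rate at iteration i is the entry with the largest key <= i.
// This mimics the intent of lrSchedule above; it is not ND4J's MapSchedule code.
public class LrScheduleSketch {

    // Iteration number -> learning rate, copied from lrSchedule above.
    public static final TreeMap<Integer, Double> SCHEDULE = new TreeMap<>();
    static {
        SCHEDULE.put(0, 0.06);
        SCHEDULE.put(200, 0.05);
        SCHEDULE.put(600, 0.028);
        SCHEDULE.put(800, 0.0060);
        SCHEDULE.put(1000, 0.001);
    }

    /** Learning rate in effect at the given iteration. */
    public static double learningRate(int iteration) {
        return SCHEDULE.floorEntry(iteration).getValue();
    }

    /** Minibatches per epoch, counting a final partial batch. */
    public static int iterationsPerEpoch(int numExamples, int batchSize) {
        return (numExamples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        System.out.println(iterationsPerEpoch(60_000, 54)); // 1112 (1111 full batches + 1 batch of 6)
        for (int it : new int[] {0, 199, 200, 650, 1111}) {
            System.out.println(it + " -> " + learningRate(it));
        }
    }
}
```

Under these assumptions, with 1,112 iterations per epoch the last rate (0.001) is in effect for roughly the final 112 iterations.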

Next comes the core of the program: building the neural network. Layers are added by calling the layer() method on the builder returned by NeuralNetConfiguration.Builder().list(). No separate input layer needs to be added, so the first layer is a ConvolutionLayer, followed by a SubsamplingLayer (pooling layer). This pair is repeated once more, followed by a DenseLayer (fully connected layer), and finally an `OutputLayer` (output layer). This is a CNN (convolutional neural network) configuration commonly used for image recognition.

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .l2(0.0005)
        .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, lrSchedule)))
        .weightInit(WeightInit.XAVIER)
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)
            .nIn(channels)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)
            .stride(1, 1) // nIn need not be specified in later layers
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
            .nOut(500).build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nOut(outputNum)
            .activation(Activation.SOFTMAX)
            .build())
        .setInputType(InputType.convolutionalFlat(28, 28, 1)) // InputType.convolutional for normal image
        .backprop(true).pretrain(false).build();

For the activation functions, `Activation.IDENTITY` means the identity function ($f(x) = x$, that is, no transformation), `Activation.RELU` means the ReLU function, and `Activation.SOFTMAX` means the softmax function.
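For readers who want to see what these functions compute, here is a plain-Java sketch of the three. It is an illustration only, not how DL4J implements `Activation` internally; the class and method names are made up for this example.

```java
import java.util.Arrays;

// Plain-Java versions of the three activation functions named above,
// for illustration only (DL4J applies them to ND4J arrays internally).
public class Activations {

    /** Identity: f(x) = x. */
    public static double identity(double x) {
        return x;
    }

    /** ReLU: f(x) = max(0, x). */
    public static double relu(double x) {
        return Math.max(0.0, x);
    }

    /** Softmax: exp(x_i) / sum_j exp(x_j); outputs are positive and sum to 1. */
    public static double[] softmax(double[] x) {
        // Subtracting the max does not change the result but avoids overflow in exp().
        double max = Arrays.stream(x).max().orElse(0.0);
        double[] e = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            e[i] = Math.exp(x[i] - max);
            sum += e[i];
        }
        for (int i = 0; i < x.length; i++) {
            e[i] /= sum;
        }
        return e;
    }

    public static void main(String[] args) {
        System.out.println(identity(-2.5)); // -2.5
        System.out.println(relu(-2.5));     // 0.0
        System.out.println(relu(3.0));      // 3.0
        System.out.println(Arrays.toString(softmax(new double[] {1.0, 2.0, 3.0})));
    }
}
```

Softmax turns the 10 raw outputs of the last layer into probabilities that sum to 1, which is why it is used in the output layer together with the NEGATIVELOGLIKELIHOOD loss.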

Since this may be hard to follow from words alone, I have illustrated the configuration of the neural network.

[Image: cnn1.png]

Compare this figure with the source code, and check resources such as the Deeplearning4j Cheat Sheet for details (explaining everything would make this article too long).
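The layer sizes in the configuration can also be checked by hand. The sketch below assumes no padding (which I believe corresponds to DL4J's default truncate convolution mode), so each convolution or pooling layer produces (input - kernel) / stride + 1 pixels per side; the helper names are hypothetical. Under these assumptions, the total should match what `net.numParams()` reports in the debug log.

```java
// Back-of-the-envelope check of the layer sizes in the configuration above.
// Assumption: no padding, so output = (input - kernel) / stride + 1.
public class LeNetShapes {

    /** Output width/height of a convolution or pooling layer without padding. */
    public static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Weights plus biases of a convolution layer (pooling layers have none). */
    public static long convParams(int kernel, int nIn, int nOut) {
        return (long) kernel * kernel * nIn * nOut + nOut;
    }

    /** Weights plus biases of a fully connected (dense or output) layer. */
    public static long denseParams(int nIn, int nOut) {
        return (long) nIn * nOut + nOut;
    }

    public static void main(String[] args) {
        int s = 28;            // 28x28x1 input
        s = outSize(s, 5, 1);  // layer 0, conv 5x5, 20 filters  -> 24x24x20
        s = outSize(s, 2, 2);  // layer 1, max pool 2x2          -> 12x12x20
        s = outSize(s, 5, 1);  // layer 2, conv 5x5, 50 filters  -> 8x8x50
        s = outSize(s, 2, 2);  // layer 3, max pool 2x2          -> 4x4x50
        int flat = s * s * 50; // values fed into the dense layer

        long total = convParams(5, 1, 20)    // 520
                   + convParams(5, 20, 50)   // 25,050
                   + denseParams(flat, 500)  // 400,500
                   + denseParams(500, 10);   // 5,010
        System.out.println("Dense input size: " + flat); // 800
        System.out.println("Total params: " + total);    // 431080
    }
}
```

Most of the parameters sit in the fully connected layer (800 x 500 weights), even though the convolution layers do most of the feature extraction.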

Moving on: passing a listener to the setListeners() method of MultiLayerNetwork makes the training status output periodically. Here, ScoreIterationListener(10) logs the score every 10 iterations, as seen in the output earlier.

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(10));
    log.debug("Total num of params: {}", net.numParams());

Inside the loop, the fit() method trains the network for one epoch on the training data. When that is complete, the test data is passed to MultiLayerNetwork.evaluate() to evaluate the model, and both iterators are reset. Finally, the learned parameters are saved to minist-model.zip (the third argument, true, also saves the updater state).

    // evaluation while training (the score should go down)
    for (int i = 0; i < nEpochs; i++) {
      net.fit(trainIter);
      log.info("Completed epoch {}", i);
      Evaluation eval = net.evaluate(testIter);
      log.info(eval.stats());
      trainIter.reset();
      testIter.reset();
    }
    ModelSerializer.writeModel(net, new File(basePath + "/minist-model.zip"), true);
  }
}

A separate class, MnistClassifierUI, loads minist-model.zip to rebuild the neural network and "predict" the digits in handwritten images. I won't go into MnistClassifierUI in detail.
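Whatever the UI code looks like, the last step of a prediction is easy to picture: the output layer's softmax produces 10 probabilities, one per digit, and the predicted digit is the index of the largest one. A plain-Java sketch of that step follows; the helper name and the sample probabilities are hypothetical, and this is not code from MnistClassifierUI.

```java
// Minimal sketch: turn the network's 10 output probabilities into the digit
// shown after "Prediction:" by taking the index of the largest value.
// Not code from MnistClassifierUI; names and sample values are hypothetical.
public class PredictDigit {

    /** Returns the index of the largest probability, i.e. the predicted digit. */
    public static int argMax(double[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical softmax output for an image of a "7".
        double[] out = {0.01, 0.00, 0.02, 0.01, 0.00, 0.00, 0.00, 0.93, 0.01, 0.02};
        System.out.println("Prediction: " + argMax(out)); // Prediction: 7
    }
}
```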

Applications

Let's modify the MnistClassifier source code a little and try a few experiments.

Graph the training progress

Let's replace the listener passed to the setListeners() method of MultiLayerNetwork with a different one: the listener class introduced on this page.

    // net.setListeners(new ScoreIterationListener(10));
    // Comment out the line above and add the four lines below
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    net.setListeners(Arrays.asList(new ScoreIterationListener(1), new StatsListener(statsStorage)));

After modifying the source code, run the program again. The following line is printed to standard output:

o.d.u.p.PlayUIServer - DL4J UI Server started at http://localhost:9000

Open http://localhost:9000 in a browser, and you will see graphs that visualize the current training progress in an easy-to-understand way, as shown below.

[Image: Screenshot from 2018-12-12 14-35-36.png]

Note: To display the UI in Japanese, click the "Language" tab on the right side of the screen and select Japanese.

You can see the configuration of this neural network in a simple diagram by clicking the "System" tab.

[Image: Screenshot from 2018-12-12 15-21-59.png]

Change the optimization algorithm

Next, let's change the optimization algorithm to stochastic gradient descent (SGD) by replacing the Nesterovs argument of the `updater()` method of NeuralNetConfiguration.Builder() with Sgd.

And when I ran the program, I got the following result.

========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0.9698
 Precision:       0.9696
 Recall:          0.9697
 F1 Score:        0.9697
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  969    0    1    0    0    2    3    1    4    0 | 0 = 0
    0 1120    3    2    0    1    3    0    6    0 | 1 = 1
    6    2  993    4    6    3    3    9    6    0 | 2 = 2
    1    0    7  976    0    7    0    9    7    3 | 3 = 3
    1    1    2    0  955    0    5    2    2   14 | 4 = 4
    2    1    0   11    1  866    5    1    3    2 | 5 = 5
   10    3    1    0    6    3  933    0    2    0 | 6 = 6
    2    8   16    2    1    0    0  982    3   14 | 7 = 7
    6    0    1    4    4    5    4    6  941    3 | 8 = 8
    5    7    0    9   11    7    1    5    1  963 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================

Accuracy dropped slightly, from 0.9891 to 0.9698. I tried a few other optimizers, and for this case Nesterovs (Nesterov's accelerated gradient descent) seems to be a good choice.
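To get a feel for the difference between the two updaters, here is a toy comparison on a one-dimensional quadratic. It illustrates the update rules only and is not Deeplearning4j's Sgd or Nesterovs code; the function, the learning rate, and the momentum of 0.9 are assumptions chosen for the demo, and a toy like this does not by itself explain the MNIST results above.

```java
// Toy comparison of plain SGD and Nesterov momentum on f(x) = a*x^2/2,
// whose gradient is a*x. Illustrates the update rules only; this is not
// DL4J's Sgd/Nesterovs code, and all constants are demo assumptions.
public class UpdaterToy {

    static final double A = 0.01; // curvature of the (very flat) quadratic
    static final double LR = 0.1; // learning rate
    static final double MU = 0.9; // momentum coefficient

    static double grad(double x) {
        return A * x;
    }

    /** Plain SGD: x <- x - lr * grad(x). */
    public static double sgd(double x, int steps) {
        for (int t = 0; t < steps; t++) {
            x -= LR * grad(x);
        }
        return x;
    }

    /** Nesterov momentum: the gradient is evaluated at the look-ahead point x + mu*v. */
    public static double nesterov(double x, int steps) {
        double v = 0.0;
        for (int t = 0; t < steps; t++) {
            v = MU * v - LR * grad(x + MU * v);
            x += v;
        }
        return x;
    }

    public static void main(String[] args) {
        // Start at x = 10; the minimum is at x = 0.
        System.out.println("SGD:      " + sgd(10.0, 100));
        System.out.println("Nesterov: " + nesterov(10.0, 100));
    }
}
```

On this flat function, Nesterov momentum ends much closer to the minimum after 100 steps (roughly 3.7 versus 9.0), because the velocity term accumulates past gradients while the look-ahead gradient keeps it from overshooting.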

Set the initial weights to zero

Now let's deliberately run the program with a bad setting. Pass WeightInit.ZERO to the weightInit() method of NeuralNetConfiguration.Builder() to set all initial weights to zero.

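With this setting, the score stays at around 2.3 with almost no change (about $\ln 10 \approx 2.303$, the loss you get when the predicted probability is spread almost evenly over the 10 classes), and in the end every image is predicted to be "1".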

=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
    0  980    0    0    0    0    0    0    0    0 | 0 = 0
    0 1135    0    0    0    0    0    0    0    0 | 1 = 1
    0 1032    0    0    0    0    0    0    0    0 | 2 = 2
    0 1010    0    0    0    0    0    0    0    0 | 3 = 3
    0  982    0    0    0    0    0    0    0    0 | 4 = 4
    0  892    0    0    0    0    0    0    0    0 | 5 = 5
    0  958    0    0    0    0    0    0    0    0 | 6 = 6
    0 1028    0    0    0    0    0    0    0    0 | 7 = 7
    0  974    0    0    0    0    0    0    0    0 | 8 = 8
    0 1009    0    0    0    0    0    0    0    0 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================

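This is because all weights start with the same value, so every neuron in a layer computes the same output and receives the same update; the symmetry between them is never broken. With all weights at exactly zero, the gradients reaching the weights are zero as well, so presumably only the output layer's biases learn anything, and the network falls back to predicting the most frequent digit in the training data, which is "1".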
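The symmetry argument can be seen in a network small enough to inspect. The sketch below is a hypothetical 2-2-1 network (sigmoid units, no biases) trained on a toy XOR-style dataset with plain gradient descent; its size, data, and learning rate are assumptions, and it is not DL4J code. Because both hidden units start with the same weights, they compute the same output and receive the same gradient at every step.

```java
import java.util.Arrays;

// Tiny demonstration of the symmetry problem with identical initial weights.
// Hypothetical 2-input, 2-hidden, 1-output network (sigmoid, no biases),
// trained by plain gradient descent on a toy dataset. Not DL4J code.
public class SymmetryDemo {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Trains from all-equal weights and returns the input-to-hidden weights w1[hidden][input]. */
    public static double[][] train(double initialWeight, int epochs) {
        double[][] w1 = {{initialWeight, initialWeight}, {initialWeight, initialWeight}};
        double[] w2 = {initialWeight, initialWeight};
        double[][] xs = {{0, 1}, {1, 0}, {1, 1}, {0, 0}};
        double[] ys = {1, 1, 0, 0};
        double lr = 0.5;

        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int n = 0; n < xs.length; n++) {
                double[] x = xs[n];
                // Forward pass.
                double[] h = new double[2];
                for (int j = 0; j < 2; j++) {
                    h[j] = sigmoid(w1[j][0] * x[0] + w1[j][1] * x[1]);
                }
                double out = sigmoid(w2[0] * h[0] + w2[1] * h[1]);

                // Backward pass for cross-entropy loss with a sigmoid output: dLoss/dz = out - y.
                double dOut = out - ys[n];
                for (int j = 0; j < 2; j++) {
                    double dHidden = dOut * w2[j] * h[j] * (1 - h[j]); // uses w2[j] before its update
                    w2[j] -= lr * dOut * h[j];
                    for (int i = 0; i < 2; i++) {
                        w1[j][i] -= lr * dHidden * x[i];
                    }
                }
            }
        }
        return w1;
    }

    public static void main(String[] args) {
        double[][] w1 = train(0.0, 1000);
        // Both hidden units received identical updates, so their weights are still identical.
        System.out.println(Arrays.toString(w1[0]));
        System.out.println(Arrays.toString(w1[1]));
        System.out.println("identical: " + Arrays.equals(w1[0], w1[1])); // identical: true
    }
}
```

Changing the shared starting value (for example to 0.5) does not help as long as all weights start equal; a random initialization such as the WeightInit.XAVIER used in the original code is what breaks the symmetry.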

Summary

In this article, we identified handwritten digits with Deeplearning4j. With this, I think we have taken the first step toward deep learning in Java. The cloned repository contains many other samples, so as a next step you may want to try running another one. If you want to learn the theory, I recommend reading the book mentioned above.

References

- Deeplearning4j Official Page
- Deeplearning4j's GitHub
- Qiita: Introduction to DeepLearning4J
