Deep Learning from Scratch in Java: Chapter 3, Neural Networks


3.2 Activation function

3.2.3 Implementation of step function

The step function can be implemented as follows. Here I also define a helper that maps a simple function (DoubleFunction) over an INDArray.

public static double step_function(double x) {
    if (x > 0)
        return 1.0;
    else
        return 0.0;
}

public static <T extends Number> INDArray map(INDArray x, DoubleFunction<T> func) {
    int size = x.length();
    // Nd4j.create(size) creates a 1xN row vector.
    INDArray result = Nd4j.create(size);
    // Apply func to each element and store the result.
    for (int i = 0; i < size; ++i)
        result.put(0, i, func.apply(x.getDouble(i)));
    return result;
}

public static INDArray step_function(INDArray x) {
    return map(x, d -> d > 0.0 ? 1 : 0);
}

INDArray x = Nd4j.create(new double[] {-1.0, 1.0, 2.0});
assertEquals("[-1.00,1.00,2.00]", Util.string(x));
assertEquals("[0.00,1.00,1.00]", Util.string(step_function(x)));

3.2.4 Implementation of sigmoid function

The sigmoid function can be implemented as follows. ND4J also provides a sigmoid function in the Transforms class, so you can use that as well.

public static double sigmoid(double x) {
    return 1.0 / (1.0 + Math.exp(-x));
}

public static INDArray sigmoid(INDArray x) {
    // Operators cannot be overloaded in Java,
    // so the expression is written as a chain of method calls.
    return Transforms.exp(x.neg()).add(1.0).rdiv(1.0);
    // Alternatively, it can be implemented with the map defined above:
    // return map(x, d -> sigmoid(d));
}


INDArray x = Nd4j.create(new double[] {-1.0, 1.0, 2.0});
assertEquals("[0.27,0.73,0.88]", Util.string(sigmoid(x)));
assertEquals("[0.27,0.73,0.88]", Util.string(Transforms.sigmoid(x)));
INDArray t = Nd4j.create(new double[] {1.0, 2.0, 3.0});
// A.rdiv(k) divides k by each element of A.
assertEquals("[1.00,0.50,0.33]", Util.string(t.rdiv(1.0)));

3.2.7 ReLU function

The ReLU function can be implemented as follows. You can also use the relu function in ND4J's Transforms class.

public static INDArray relu(INDArray x) {
    return map(x, d -> Math.max(0.0, d));
}

INDArray x = Nd4j.create(new double[] {-4, -2, 0, 2, 4});
assertEquals("[0.00,0.00,0.00,2.00,4.00]", Util.string(relu(x)));
assertEquals("[0.00,0.00,0.00,2.00,4.00]", Util.string(Transforms.relu(x)));

3.3 Calculating multidimensional arrays

3.3.1 Multidimensional arrays

// One-dimensional array
INDArray A = Nd4j.create(new double[] {1, 2, 3, 4});
assertEquals("[1.00,2.00,3.00,4.00]", Util.string(A));
// In ND4J, a one-dimensional array is represented as a 1xN two-dimensional array.
assertArrayEquals(new int[] {1, 4}, A.shape());
// In ND4J, the number of dimensions is obtained with the rank() method.
assertEquals(2, A.rank());
assertEquals(1, A.size(0));  // Number of rows
assertEquals(4, A.size(1));  // Number of columns
// Two-dimensional array
INDArray B = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertEquals("[[1.00,2.00],[3.00,4.00],[5.00,6.00]]", Util.string(B));
assertEquals(2, B.rank());
assertArrayEquals(new int[] {3, 2}, B.shape());

3.3.2 Matrix multiplication

In ND4J, the matrix product is calculated with INDArray.mmul(INDArray).

INDArray A = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
assertArrayEquals(new int[] {2, 2}, A.shape());
INDArray B = Nd4j.create(new double[][] {{5, 6}, {7, 8}});
assertArrayEquals(new int[] {2, 2}, B.shape());
assertEquals("[[19.00,22.00],[43.00,50.00]]", Util.string(A.mmul(B)));

A = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
assertArrayEquals(new int[] {2, 3}, A.shape());
B = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertArrayEquals(new int[] {3, 2}, B.shape());
assertEquals("[[22.00,28.00],[49.00,64.00]]", Util.string(A.mmul(B)));

INDArray C = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
assertArrayEquals(new int[] {2, 2}, C.shape());
assertArrayEquals(new int[] {2, 3}, A.shape());
try {
    // In ND4J, if the shapes of the matrices being multiplied are incompatible,
    // an ND4JIllegalStateException is thrown.
    A.mmul(C);
    fail();
} catch (ND4JIllegalStateException e) {
    assertEquals(
        "Cannot execute matrix multiplication: [2, 3]x[2, 2]: "
        + "Column of left array 3 != rows of right 2"
        , e.getMessage());
}

A = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}});
assertArrayEquals(new int[] {3, 2}, A.shape());
B = Nd4j.create(new double[] {7, 8});
assertArrayEquals(new int[] {1, 2}, B.shape());
// In ND4J, a one-dimensional array is a 1xN matrix,
// so it must be transposed with the transpose() method before taking the product.
assertArrayEquals(new int[] {2, 1}, B.transpose().shape());
assertEquals("[23.00,53.00,83.00]", Util.string(A.mmul(B.transpose())));

3.3.3 Neural network matrix multiplication

INDArray X = Nd4j.create(new double[] {1, 2});
assertArrayEquals(new int[] {1, 2}, X.shape());
INDArray W = Nd4j.create(new double[][] {{1, 3, 5}, {2, 4, 6}});
assertEquals("[[1.00,3.00,5.00],[2.00,4.00,6.00]]", Util.string(W));
assertArrayEquals(new int[] {2, 3}, W.shape());
INDArray Y = X.mmul(W);
assertEquals("[5.00,11.00,17.00]", Util.string(Y));

3.4 Implementation of 3-layer neural network

3.4.2 Implementation of signal transmission in each layer

INDArray X = Nd4j.create(new double[] {1.0, 0.5});
INDArray W1 = Nd4j.create(new double[][] {{0.1, 0.3, 0.5}, {0.2, 0.4, 0.6}});
INDArray B1 = Nd4j.create(new double[] {0.1, 0.2, 0.3});
assertArrayEquals(new int[] {2, 3}, W1.shape());
assertArrayEquals(new int[] {1, 2}, X.shape());
assertArrayEquals(new int[] {1, 3}, B1.shape());
INDArray A1 = X.mmul(W1).add(B1);
INDArray Z1 = Transforms.sigmoid(A1);
assertEquals("[0.30,0.70,1.10]", Util.string(A1));
assertEquals("[0.57,0.67,0.75]", Util.string(Z1));

INDArray W2 = Nd4j.create(new double[][] {{0.1, 0.4}, {0.2, 0.5}, {0.3, 0.6}});
INDArray B2 = Nd4j.create(new double[] {0.1, 0.2});
assertArrayEquals(new int[] {1, 3}, Z1.shape());
assertArrayEquals(new int[] {3, 2}, W2.shape());
assertArrayEquals(new int[] {1, 2}, B2.shape());
INDArray A2 = Z1.mmul(W2).add(B2);
INDArray Z2 = Transforms.sigmoid(A2);
assertEquals("[0.52,1.21]", Util.string(A2));
assertEquals("[0.63,0.77]", Util.string(Z2));

INDArray W3 = Nd4j.create(new double[][] {{0.1, 0.3}, {0.2, 0.4}});
INDArray B3 = Nd4j.create(new double[] {0.1, 0.2});
INDArray A3 = Z2.mmul(W3).add(B3);
// ND4J provides an identity(INDArray) method in the Transforms class.
INDArray Y = Transforms.identity(A3);
assertEquals("[0.32,0.70]", Util.string(A3));
assertEquals("[0.32,0.70]", Util.string(Y));
// Y.equals(A3) is true.
assertEquals(A3, Y);

3.4.3 Implementation Summary

public static Map<String, INDArray> init_network() {
    Map<String, INDArray> network = new HashMap<>();
    network.put("W1", Nd4j.create(new double[][] {{0.1, 0.3, 0.5}, {0.2, 0.4, 0.6}}));
    network.put("b1", Nd4j.create(new double[] {0.1, 0.2, 0.3}));
    network.put("W2", Nd4j.create(new double[][] {{0.1, 0.4}, {0.2, 0.5}, {0.3, 0.6}}));
    network.put("b2", Nd4j.create(new double[] {0.1, 0.2}));
    network.put("W3", Nd4j.create(new double[][] {{0.1, 0.3}, {0.2, 0.4}}));
    network.put("b3", Nd4j.create(new double[] {0.1, 0.2}));
    return network;
}

public static INDArray forward(Map<String, INDArray> network, INDArray x) {
    INDArray W1 = network.get("W1");
    INDArray W2 = network.get("W2");
    INDArray W3 = network.get("W3");
    INDArray b1 = network.get("b1");
    INDArray b2 = network.get("b2");
    INDArray b3 = network.get("b3");

    INDArray a1 = x.mmul(W1).add(b1);
    INDArray z1 = Transforms.sigmoid(a1);
    INDArray a2 = z1.mmul(W2).add(b2);
    INDArray z2 = Transforms.sigmoid(a2);
    INDArray a3 = z2.mmul(W3).add(b3);
    INDArray y = Transforms.identity(a3);
    return y;
}

Map<String, INDArray> network = init_network();
INDArray x = Nd4j.create(new double[] {1.0, 0.5});
INDArray y = forward(network, x);
assertEquals("[0.32,0.70]", Util.string(y));

3.5 Output layer design

3.5.1 Identity and softmax functions

INDArray a = Nd4j.create(new double[] {0.3, 2.9, 4.0});
// Exponential of each element
INDArray exp_a = Transforms.exp(a);
assertEquals("[1.35,18.17,54.60]", Util.string(exp_a));
//Sum of exponential functions
Number sum_exp_a = exp_a.sumNumber();
assertEquals(74.1221542102, sum_exp_a.doubleValue(), 5e-6);
//Softmax function
INDArray y = exp_a.div(sum_exp_a);
assertEquals("[0.02,0.25,0.74]", Util.string(y));

3.5.2 Precautions for implementing softmax function
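The naive implementation below overflows when the elements of a are large: exp(1010) is far beyond the range of a double, so the result becomes NaN. The softmax function is unchanged when the same constant c is subtracted from every element, because exp(a_i - c) / Σ_j exp(a_j - c) = exp(a_i) / Σ_j exp(a_j) (the common factor exp(-c) cancels). Subtracting the maximum element therefore keeps every exponent at or below zero and avoids overflow.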

public static INDArray softmax_wrong(INDArray a) {
    // Overflows when the elements of a are large, yielding NaN.
    INDArray exp_a = Transforms.exp(a);
    Number sum_exp_a = exp_a.sumNumber();
    INDArray y = exp_a.div(sum_exp_a);
    return y;
}

public static INDArray softmax_right(INDArray a) {
    // Subtract the maximum element before exponentiating to avoid overflow.
    Number c = a.maxNumber();
    INDArray exp_a = Transforms.exp(a.sub(c));
    Number sum_exp_a = exp_a.sumNumber();
    INDArray y = exp_a.div(sum_exp_a);
    return y;
}

INDArray a = Nd4j.create(new double[] {1010, 1000, 990});
// Not calculated correctly: exp overflows, so the result is NaN.
assertEquals("[NaN,NaN,NaN]", Util.string(Transforms.exp(a).div(Transforms.exp(a).sumNumber())));
Number c = a.maxNumber();
assertEquals("[0.00,-10.00,-20.00]", Util.string(a.sub(c)));
assertEquals("[1.00,0.00,0.00]", Util.string(Transforms.exp(a.sub(c)).div(Transforms.exp(a.sub(c)).sumNumber())));

// Wrong
assertEquals("[NaN,NaN,NaN]", Util.string(softmax_wrong(a)));
// Correct
assertEquals("[1.00,0.00,0.00]", Util.string(softmax_right(a)));
// ND4J provides a correct softmax(INDArray) in the Transforms class.
assertEquals("[1.00,0.00,0.00]", Util.string(Transforms.softmax(a)));

3.5.3 Features of softmax function

INDArray a = Nd4j.create(new double[] {0.3, 2.9, 4.0});
INDArray y = Transforms.softmax(a);
assertEquals("[0.02,0.25,0.74]", Util.string(y));
//The sum is 1.
assertEquals(1.0, y.sumNumber().doubleValue(), 5e-6);

3.6 Handwritten digit recognition

3.6.1 MNIST dataset

MNIST data is read with the MNISTImages class (src/main/java/deep/learning/common/MNISTImages.java). This class reads the files after you download and unzip the data from THE MNIST DATABASE of handwritten digits.
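The MNIST files use the simple IDX binary format: big-endian integers giving a magic number, the item count, and (for image files) the row and column counts, followed by one unsigned byte per label or pixel. As a rough idea of what such a reader involves, here is a minimal sketch; the actual MNISTImages class in the repository may differ, and the helper name readImages is illustrative.

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

static byte[][] readImages(File file) throws IOException {
    try (DataInputStream in = new DataInputStream(
            new BufferedInputStream(new FileInputStream(file)))) {
        int magic = in.readInt();    // 0x00000803 for image files (big-endian)
        if (magic != 0x00000803)
            throw new IOException("Not an IDX image file: " + magic);
        int count = in.readInt();    // number of images (60000 or 10000)
        int rows = in.readInt();     // 28
        int cols = in.readInt();     // 28
        byte[][] images = new byte[count][rows * cols];
        for (int i = 0; i < count; ++i)
            in.readFully(images[i]); // one unsigned byte per pixel
        return images;
    }
}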

// Load the MNIST dataset using the MNISTImages class.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
assertEquals(60000, train.size);
assertEquals(784, train.imageSize);
assertEquals(10000, test.size);
assertEquals(784, test.imageSize);

// Output the first 100 training images as PNG files.
if (!Constants.TrainImagesOutput.exists())
    Constants.TrainImagesOutput.mkdirs();
for (int i = 0; i < 100; ++i) {
    File image = new File(Constants.TrainImagesOutput,
        String.format("%05d-%d.png ", i, train.label(i)));
    train.writePngFile(i, image);
}

The actual output images are shown below. Each file name has the form "[5-digit serial number]-[label].png".

00000-5.png 00001-0.png 00002-4.png 00003-1.png 00004-9.png 00005-2.png 00006-1.png 00007-3.png 00008-1.png 00009-4.png .....

3.6.2 Neural network inference processing

The sample weight data (sample_weight.pkl) is loaded using the [SampleWeight](https://github.com/saka1029/Deep.Learning/blob/master/src/main/java/deep/learning/common/SampleWeight.java) class. However, sample_weight.pkl is serialized with Python's pickle and cannot be read directly from Java, so it is first converted to text using the following Python program. The converted data is SampleWeight.txt, which the SampleWeight class reads.

sample_weight.py


import pickle
import numpy

pkl = "sample_weight.pkl"
with open(pkl, "rb") as f:
	network = pickle.load(f)
for k, v in network.items():
    print(k, end="")
    dim = v.ndim
    for d in v.shape:
        print("", d, end="")
    print()
    for e in v.flatten():
        print(e)
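For reference, below is a minimal sketch of how SampleWeight.txt could be read back into a Map<String, INDArray>; the actual SampleWeight class in the repository may differ, and the method name readWeights is illustrative. Each entry consists of a header line (the array name followed by its shape) and then one flattened element per line.

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

static Map<String, INDArray> readWeights(File file) throws IOException {
    Map<String, INDArray> network = new HashMap<>();
    try (BufferedReader in = new BufferedReader(new FileReader(file))) {
        String header;
        while ((header = in.readLine()) != null) {
            // Header line: "name dim1 [dim2 ...]"
            String[] fields = header.trim().split("\\s+");
            String name = fields[0];
            int[] shape = new int[fields.length - 1];
            int length = 1;
            for (int i = 0; i < shape.length; ++i) {
                shape[i] = Integer.parseInt(fields[i + 1]);
                length *= shape[i];
            }
            // One flattened element per line.
            double[] data = new double[length];
            for (int i = 0; i < length; ++i)
                data[i] = Double.parseDouble(in.readLine());
            INDArray array = shape.length == 1
                ? Nd4j.create(data)
                : Nd4j.create(data).reshape(shape);
            network.put(name, array);
        }
    }
    return network;
}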

The code that actually loads the weights and performs inference is as follows. The recognition accuracy is 93.52%, the same as in [Deep Learning from scratch](https://www.oreilly.co.jp/books/9784873117584/).

static INDArray normalize(byte[][] images) {
    int imageCount = images.length;
    int imageSize = images[0].length;
    INDArray norm = Nd4j.create(imageCount, imageSize);
    for (int i = 0; i < imageCount; ++i)
        for (int j = 0; j < imageSize; ++j)
            norm.putScalar(i, j, (images[i][j] & 0xff) / 255.0);
    return norm;
}

static INDArray predict(Map<String, INDArray> network, INDArray x) {
    INDArray W1 = network.get("W1");
    INDArray W2 = network.get("W2");
    INDArray W3 = network.get("W3");
    INDArray b1 = network.get("b1");
    INDArray b2 = network.get("b2");
    INDArray b3 = network.get("b3");

    // Writing the following causes an error during batch processing:
    // INDArray a1 = x.mmul(W1).add(b1);
    // This is because the result of x.mmul(W1) is a two-dimensional array
    // while b1 is one-dimensional, and add(INDArray) does not broadcast automatically.
    // Broadcasting can also be done explicitly:
    // INDArray a1 = x.mmul(W1).add(b1.broadcast(x.size(0), b1.size(1)));
    INDArray a1 = x.mmul(W1).addRowVector(b1);
    INDArray z1 = Transforms.sigmoid(a1);
    INDArray a2 = z1.mmul(W2).addRowVector(b2);
    INDArray z2 = Transforms.sigmoid(a2);
    INDArray a3 = z2.mmul(W3).addRowVector(b3);
    INDArray y = Transforms.softmax(a3);

    return y;
}

// Load the test images.
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
// Load the sample weight data.
Map<String, INDArray> network = SampleWeight.read(Constants.SampleWeights);
// Normalize the images (0-255 -> 0.0-1.0).
INDArray x = test.normalizedImages();
int size = x.size(0);
int accuracy_cnt = 0;
for (int i = 0; i < size; ++i) {
    INDArray y = predict(network, x.getRow(i));
    // The last argument, 1, specifies the dimension along which to find the maximum index.
    INDArray max = Nd4j.getExecutioner().exec(new IAMax(y), 1);
    if (max.getInt(0) == test.label(i))
        ++accuracy_cnt;
}
//        System.out.printf("Accuracy:%f%n", (double) accuracy_cnt / size);
assertEquals(10000, size);
assertEquals(9352, accuracy_cnt);

3.6.3 Batch processing

Use the NDArrayIndex class to retrieve batch-size chunks of data from an INDArray. NDArrayIndex.interval(int start, int end) retrieves the elements i in the range start ≤ i < end. The IAMax class serves the same purpose as NumPy's argmax function.
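Here is a small standalone illustration, assuming the same Util.string helper used above:

INDArray m = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {6, 5}, {7, 8}});
// Rows 1 (inclusive) to 3 (exclusive).
INDArray rows = m.get(NDArrayIndex.interval(1, 3));
assertEquals("[[3.00,4.00],[6.00,5.00]]", Util.string(rows));
// Index of the maximum in each row (dimension 1), like NumPy's argmax.
INDArray max = Nd4j.getExecutioner().exec(new IAMax(rows), 1);
assertEquals(1, max.getInt(0));  // in {3,4} the maximum is at column 1
assertEquals(0, max.getInt(1));  // in {6,5} the maximum is at column 0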

int batch_size = 100;
// Load the test images.
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
// Load the sample weight data.
Map<String, INDArray> network = SampleWeight.read(Constants.SampleWeights);
// Normalize the images (0-255 -> 0.0-1.0).
INDArray x = test.normalizedImages();
int size = x.size(0);
int accuracy_cnt = 0;
for (int i = 0; i < size; i += batch_size) {
    // Extract batch_size images at a time and call predict().
    INDArray y = predict(network, x.get(NDArrayIndex.interval(i, i + batch_size)));
    // The last argument, 1, specifies the dimension along which to find the maximum index.
    INDArray max = Nd4j.getExecutioner().exec(new IAMax(y), 1);
    for (int j = 0; j < batch_size; ++j)
        if (max.getInt(j) == test.label(i + j))
            ++accuracy_cnt;
}
//        System.out.printf("Accuracy:%f%n", (double) accuracy_cnt / size);
assertEquals(10000, size);
assertEquals(9352, accuracy_cnt);

In my environment, batch processing reduced the running time from 2.7 seconds to 0.5 seconds, without using a GPU.
