Deep Learning from Scratch in Java: Chapter 5. Error Backpropagation Method


5.4 Simple layer implementation

5.4.1 Multiplication layer implementation

static class MulLayer {

    private double x, y;

    public double forward(double x, double y) {
        this.x = x;
        this.y = y;
        return x * y;
    }

    //Swap x and y for the backward pass. Since Java cannot return multiple
    //values, the partial derivatives {dx, dy} are returned as an array.
    public double[] backward(double dout) {
        return new double[] {dout * y, dout * x};
    }
}

double apple = 100;
double apple_num = 2;
double tax = 1.1;
// Layer
MulLayer mul_apple_layer = new MulLayer();
MulLayer mul_tax_layer = new MulLayer();
// forward
double apple_price = mul_apple_layer.forward(apple, apple_num);
double price = mul_tax_layer.forward(apple_price, tax);
assertEquals(220.0, price, 5e-6);
// backward
double dprice = 1;
double[] dapple_price_tax = mul_tax_layer.backward(dprice);           // {dapple_price, dtax}
double[] dapple_num = mul_apple_layer.backward(dapple_price_tax[0]);  // {dapple, dapple_num}
assertEquals(2.2, dapple_num[0], 5e-6);          // dapple
assertEquals(110.0, dapple_num[1], 5e-6);        // dapple_num
assertEquals(200.0, dapple_price_tax[1], 5e-6);  // dtax

5.4.2 Addition layer implementation

static class AddLayer {

    public double forward(double x, double y) {
        return x + y;
    }

    //An addition node passes the upstream gradient through unchanged.
    public double[] backward(double dout) {
        return new double[] {dout, dout};
    }
}

double apple = 100;
double apple_num = 2;
double orange = 150;
double orange_num = 3;
double tax = 1.1;
// Layer
MulLayer mul_apple_layer = new MulLayer();
MulLayer mul_orange_layer = new MulLayer();
AddLayer add_apple_orange_layer = new AddLayer();
MulLayer mul_tax_layer = new MulLayer();
// forward
double apple_price = mul_apple_layer.forward(apple, apple_num);
double orange_price = mul_orange_layer.forward(orange, orange_num);
double all_price = add_apple_orange_layer.forward(apple_price, orange_price);
double price = mul_tax_layer.forward(all_price, tax);
// backward
double dprice = 1;
double[] dall_price = mul_tax_layer.backward(dprice);                            // {dall_price, dtax}
double[] dapple_dorange_price = add_apple_orange_layer.backward(dall_price[0]);  // {dapple_price, dorange_price}
double[] dorange = mul_orange_layer.backward(dapple_dorange_price[1]);           // {dorange, dorange_num}
double[] dapple = mul_apple_layer.backward(dapple_dorange_price[0]);             // {dapple, dapple_num}
assertEquals(715.0, price, 5e-6);
assertEquals(110.0, dapple[1], 5e-6);
assertEquals(2.2, dapple[0], 5e-6);
assertEquals(3.3, dorange[0], 5e-6);
assertEquals(165.0, dorange[1], 5e-6);
assertEquals(650.0, dall_price[1], 5e-6);

5.5 Implementation of the activation function layers

5.5.1 ReLU layer

The ReLU layer using the error backpropagation method is implemented in the Relu class.

INDArray x = Nd4j.create(new double[][] {{1.0, -0.5}, {-2.0, 3.0}});
assertEquals("[[1.00,-0.50],[-2.00,3.00]]", Util.string(x));
//The tests here differ from the ones in the book.
Relu relu = new Relu();
INDArray a = relu.forward(x);
//Result of forward
assertEquals("[[1.00,0.00],[0.00,3.00]]", Util.string(a));
// mask
assertEquals("[[1.00,0.00],[0.00,1.00]]", Util.string(relu.mask));
INDArray dout = Nd4j.create(new double[][] {{5, 6}, {7, 8}});
INDArray b = relu.backward(dout);
//Results of backward
assertEquals("[[5.00,0.00],[0.00,8.00]]", Util.string(b));

5.5.2 Sigmoid layer

The Sigmoid layer using the error backpropagation method is implemented in the Sigmoid class.
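
The forward pass computes y = 1 / (1 + exp(-x)) and the backward pass multiplies the upstream gradient by y(1 - y), so the layer only needs to keep its forward output. A minimal sketch, assuming ND4J's Transforms.sigmoid (the actual Sigmoid class may be implemented differently):

static class Sigmoid {

    INDArray out;  // forward output y, reused in backward

    public INDArray forward(INDArray x) {
        out = Transforms.sigmoid(x);  // y = 1 / (1 + exp(-x))
        return out;
    }

    public INDArray backward(INDArray dout) {
        // dL/dx = dout * y * (1 - y)
        return dout.mul(out).mul(out.rsub(1.0));
    }
}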

5.6 Implementation of the Affine/Softmax layers

5.6.1 Affine layer

The Affine layer using the error backpropagation method is implemented in the Affine class.

try (Random r = new DefaultRandom()) {
    INDArray X = r.nextGaussian(new int[] {2});
    INDArray W = r.nextGaussian(new int[] {2, 3});
    INDArray B = r.nextGaussian(new int[] {3});
    assertArrayEquals(new int[] {1, 2}, X.shape());
    assertArrayEquals(new int[] {2, 3}, W.shape());
    assertArrayEquals(new int[] {1, 3}, B.shape());
    INDArray Y = X.mmul(W).addRowVector(B);
    assertArrayEquals(new int[] {1, 3}, Y.shape());
}

5.6.2 Batch version of Affine layer

INDArray X_dot_W = Nd4j.create(new double[][] {{0, 0, 0}, {10, 10, 10}});
INDArray B = Nd4j.create(new double[] {1, 2, 3});
assertEquals("[[0.00,0.00,0.00],[10.00,10.00,10.00]]", Util.string(X_dot_W));
assertEquals("[[1.00,2.00,3.00],[11.00,12.00,13.00]]", Util.string(X_dot_W.addRowVector(B)));

5.6.3 Softmax-with-Loss layer

The Softmax-with-Loss layer using the error backpropagation method is implemented in [SoftmaxWithLoss](https://github.com/saka1029/Deep.Learning/blob/master/src/main/java/deep/learning/common/SoftmaxWithLoss.java).
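
In outline, the layer applies softmax in the forward pass, computes the cross-entropy error against the one-hot labels t, and its backward pass is simply (y - t) divided by the batch size. A minimal sketch, assuming ND4J's Transforms.softmax and Transforms.log (the linked class may be implemented differently):

static class SoftmaxWithLoss {

    INDArray y, t;  // softmax output and one-hot teacher labels

    public double forward(INDArray x, INDArray t) {
        this.t = t;
        y = Transforms.softmax(x);  // row-wise softmax
        int batchSize = x.size(0);
        // cross-entropy error averaged over the batch; 1e-7 avoids log(0)
        return -t.mul(Transforms.log(y.add(1e-7))).sumNumber().doubleValue() / batchSize;
    }

    public INDArray backward(INDArray dout) {
        // dL/dx = (y - t) / batch_size; the argument is effectively 1 here
        int batchSize = t.size(0);
        return y.sub(t).div(batchSize);
    }
}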

5.7 Implementation of the error backpropagation method

The two-layer neural network using the error backpropagation method is implemented in [TwoLayerNet](https://github.com/saka1029/Deep.Learning/blob/master/src/main/java/deep/learning/C5/TwoLayerNet.java). A TwoLayerNet also appears in Chapter 4, but that version uses numerical differentiation. To make the individual layers easier to handle, this implementation defines the following two interfaces.

public interface Layer {

    INDArray forward(INDArray x);
    INDArray backward(INDArray x);

}

public interface LastLayer {

    double forward(INDArray x, INDArray t);
    INDArray backward(INDArray x);

}
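
With these interfaces, the forward and backward passes become simple loops over the layers. The sketch below only illustrates that idea: the parameter names and the backward seed are assumptions, not the actual TwoLayerNet code.

static double forwardBackward(List<Layer> layers, LastLayer lastLayer,
                              INDArray x, INDArray t) {
    // forward: propagate through the hidden layers, then the loss layer
    INDArray out = x;
    for (Layer layer : layers)
        out = layer.forward(out);
    double loss = lastLayer.forward(out, t);

    // backward: start from the loss layer and walk the layers in reverse order
    INDArray dout = lastLayer.backward(Nd4j.scalar(1.0));  // seed dL/dL = 1 (assumed)
    for (int i = layers.size() - 1; i >= 0; --i)
        dout = layers.get(i).backward(dout);
    return loss;
}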

5.7.3 Gradient check of the error backpropagation method

The difference from the gradient obtained by numerical differentiation is considerably larger than in the book, so the asserts below only check that each averaged difference is less than 1e-3. Functions.average(INDArray) is a method that averages all the elements of an INDArray.

public static double average(INDArray x) {
    // x.length() returns the total number of elements.
    return x.sumNumber().doubleValue() / x.length();
}
//Read the MNIST training data.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
//Take out the first 3 normalized images and one-hot labels respectively.
int batch_size = 3;
INDArray x_batch = train.normalizedImages().get(NDArrayIndex.interval(0, batch_size));
INDArray t_batch = train.oneHotLabels().get(NDArrayIndex.interval(0, batch_size));
//Find the gradient by numerical differentiation.
Params grad_numerical = network.numerical_gradient(x_batch, t_batch);
//Compute the gradient by the error backpropagation method.
Params grad_backprop = network.gradient(x_batch, t_batch);
//Compare the results of numerical differentiation and the error backpropagation method.
double diff_W1 = Functions.average(Transforms.abs(grad_backprop.get("W1").sub(grad_numerical.get("W1"))));
double diff_b1 = Functions.average(Transforms.abs(grad_backprop.get("b1").sub(grad_numerical.get("b1"))));
double diff_W2 = Functions.average(Transforms.abs(grad_backprop.get("W2").sub(grad_numerical.get("W2"))));
double diff_b2 = Functions.average(Transforms.abs(grad_backprop.get("b2").sub(grad_numerical.get("b2"))));
System.out.println("W1=" + diff_W1);
System.out.println("b1=" + diff_b1);
System.out.println("W2=" + diff_W2);
System.out.println("b2=" + diff_b2);
//The differences are a little larger than in the book.
assertTrue(diff_b1 < 1e-3);
assertTrue(diff_W2 < 1e-3);
assertTrue(diff_b2 < 1e-3);
assertTrue(diff_W1 < 1e-3);

5.7.4 Learning with the error backpropagation method

Learning with the error backpropagation method is considerably faster than learning with numerical differentiation: in my environment, 10,000 iterations complete in about 89 seconds. However, the final recognition accuracy was only about 84%, worse than with numerical differentiation. Since the gradient check above also shows a relatively large difference from the numerical gradients, there is probably something wrong with the layer implementation.

//Read the MNIST training data.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
//Read the MNIST test data.
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
DataSet dataSet = new DataSet(x_train, t_train);
int iters_num = 10000;
int train_size = x_train.size(0);
int batch_size = 100;
double learning_rate = 0.1;
List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();
int iter_per_epoch = Math.max(train_size / batch_size, 1);
for (int i = 0; i < iters_num; ++i) {
    DataSet sample = dataSet.sample(batch_size);
    INDArray x_batch = sample.getFeatures();
    INDArray t_batch = sample.getLabels();
    //Compute the gradient by the error backpropagation method
    Params grad = network.gradient(x_batch, t_batch);
    //Update the parameters (SGD)
    network.params.update((p, a) -> p.subi(a.mul(learning_rate)), grad);
    double loss = network.loss(x_batch, t_batch);
    train_loss_list.add(loss);
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.printf("loss=%f train_acc=%f test_acc=%f%n", loss, train_acc, test_acc);
    }
}
assertTrue(train_acc_list.get(train_acc_list.size() - 1) > 0.8);
assertTrue(test_acc_list.get(test_acc_list.size() - 1) > 0.8);
