table of contents

4.2 Loss function

The various functions that have appeared so far are collected in a single class (Functions) so that they can be called easily.

4.2.1 Sum of squares error

public static double mean_squared_error(INDArray y, INDArray t) {
    INDArray diff = y.sub(t);
    //It takes an inner product with its own transposed matrix.
    return 0.5 * diff.mmul(diff.transpose()).getDouble(0);

public static double mean_squared_error2(INDArray y, INDArray t) {
    //Use the ND4J squared distance function.
    return 0.5 * (double)y.squaredDistance(t);

// One-hot target whose correct class is index 2.
INDArray t = Nd4j.create(new double[] {0, 0, 1, 0, 0, 0, 0, 0, 0, 0});
// First prediction: the largest score (0.6) is on index 2.
INDArray y = Nd4j.create(new double[] {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0});
assertEquals(0.097500000000000031, mean_squared_error(y, t), 5e-6);
assertEquals(0.097500000000000031, mean_squared_error2(y, t), 5e-6);
// The same value can also be obtained with LossFunctions.LossFunction.MSE.
assertEquals(0.097500000000000031, LossFunctions.score(t, LossFunctions.LossFunction.MSE, y, 0, 0, false), 5e-6);
// Second prediction: the largest score (0.6) is on index 7, so the expected error is larger.
y = Nd4j.create(new double[] {0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0});
assertEquals(0.59750000000000003, mean_squared_error(y, t), 5e-6);
assertEquals(0.59750000000000003, mean_squared_error2(y, t), 5e-6);
assertEquals(0.59750000000000003, LossFunctions.score(t, LossFunctions.LossFunction.MSE, y, 0, 0, false), 5e-6);

4.2.2 Cross entropy error

public static double cross_entropy_error(INDArray y, INDArray t) {
    double delta = 1e-7;
    // Python: return -np.sum(t * np.log(y + delta))
    return -t.mul(Transforms.log(y.add(delta))).sumNumber().doubleValue();

// One-hot target whose correct class is index 2.
INDArray t = Nd4j.create(new double[] {0, 0, 1, 0, 0, 0, 0, 0, 0, 0});
// First prediction: 0.6 on the correct index.
INDArray y = Nd4j.create(new double[] {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0});
assertEquals(0.51082545709933802, cross_entropy_error(y, t), 5e-6);
// The same value can also be obtained with LossFunctions (MCXENT).
assertEquals(0.51082545709933802, LossFunctions.score(t, LossFunctions.LossFunction.MCXENT, y, 0, 0, false), 5e-6);
// Second prediction: only 0.1 on the correct index, so the expected error is larger.
y = Nd4j.create(new double[] {0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0});
assertEquals(2.3025840929945458, cross_entropy_error(y, t), 5e-6);

4.2.3 Mini batch learning

Use the ND4J DataSet class to randomly extract samples.

// Load the MNIST dataset.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
assertArrayEquals(new int[] {60000, 784}, train.normalizedImages().shape());
assertArrayEquals(new int[] {60000, 10}, train.oneHotLabels().shape());
// Randomly extract 10 images.
// Store the images and labels in a DataSet once, then draw the requested number of samples from it.
DataSet ds = new DataSet(train.normalizedImages(), train.oneHotLabels());
DataSet sample = ds.sample(10);
assertArrayEquals(new int[] {10, 784}, sample.getFeatureMatrix().shape());
assertArrayEquals(new int[] {10, 10}, sample.getLabels().shape());
// To confirm that each sampled image matches its label, export the images as PNG files.
// Convert the one-hot labels back to label values (the index of the maximum value of each row).
INDArray indexMax = Nd4j.getExecutioner().exec(new IAMax(sample.getLabels()), 1);
// Create the output directory when it is missing; the export itself must run either way.
if (!Constants.SampleImagesOutput.exists()) {
    Constants.SampleImagesOutput.mkdirs();
}
for (int i = 0; i < 10; ++i) {
    // The file name is "(serial number)-(label value).png", with no trailing space.
    File f = new File(Constants.SampleImagesOutput,
        String.format("%05d-%d.png",
            i, indexMax.getInt(i)));
    MNISTImages.writePngFile(sample.getFeatures().getRow(i), train.rows, train.columns, f);
}

4.2.4 [Batch compatible version] Implementation of cross entropy error

public static double cross_entropy_error2(INDArray y, INDArray t) {
    int batch_size = y.size(0);
    return -t.mul(Transforms.log(y.add(1e-7))).sumNumber().doubleValue() / batch_size;

// Single sample.
INDArray t = Nd4j.create(new double[] {0, 0, 1, 0, 0, 0, 0, 0, 0, 0});
INDArray y = Nd4j.create(new double[] {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0});
assertEquals(0.51082545709933802, cross_entropy_error2(y, t), 5e-6);
// Batch size = 2 (two identical samples), so the averaged error equals the single-sample error.
t = Nd4j.create(new double[][] {
    {0, 0, 1, 0, 0, 0, 0, 0, 0, 0},
    {0, 0, 1, 0, 0, 0, 0, 0, 0, 0}});
y = Nd4j.create(new double[][] {
    {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0},
    {0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0}});
assertEquals(0.51082545709933802, cross_entropy_error2(y, t), 5e-6);
// TODO: implement the cross entropy error for labels that are not one-hot encoded.

4.3 Numerical differentiation

4.3.1 Differentiation

public static double numerical_diff_bad(DoubleUnaryOperator f, double x) {
    double h = 10e-50;
    return (f.applyAsDouble(x + h) - f.applyAsDouble(x)) / h;

// Narrowing 1e-50 to float is expected to give 0.0, which shows the rounding error of a tiny step.
assertEquals(0.0, (float)1e-50, 1e-52);

4.3.2 Example of numerical differentiation

public static double numerical_diff(DoubleUnaryOperator f, double x) {
    double h = 1e-4;
    return (f.applyAsDouble(x + h) - f.applyAsDouble(x - h)) / (h * 2);

public double function_1(double x) {
    return 0.01 * x * x + 0.1 * x;

public double function_1_diff(double x) {
    return 0.02 * x + 0.1;
// The numerical derivative of function_1 at x = 5 and x = 10 ...
assertEquals(0.200, numerical_diff(this::function_1, 5), 5e-6);
assertEquals(0.300, numerical_diff(this::function_1, 10), 5e-6);
// ... is expected to match the analytic derivative within the same tolerance.
assertEquals(0.200, function_1_diff(5), 5e-6);
assertEquals(0.300, function_1_diff(10), 5e-6);

4.3.3 Partial differential

public double function_2(INDArray x) {
    double x0 = x.getDouble(0);
    double x1 = x.getDouble(1);
    return x0 * x0 + x1 * x1;

// Partial derivative with respect to x0 at x0 = 3.0, with x1 fixed at 4.0.
DoubleUnaryOperator function_tmp1 = x0 -> x0 * x0 + 4.0 * 4.0;
assertEquals(6.00, numerical_diff(function_tmp1, 3.0), 5e-6);
// Partial derivative with respect to x1 at x1 = 4.0, with x0 fixed at 3.0.
DoubleUnaryOperator function_tmp2 = x1 -> 3.0 * 3.0 + x1 * x1;
assertEquals(8.00, numerical_diff(function_tmp2, 4.0), 5e-6);

4.4 Gradient

public double function_2(INDArray x) {
    double x0 = x.getFloat(0);
    double x1 = x.getFloat(1);
    return x0 * x0 + x1 * x1;
    // return x.mul(x).sumNumber().doubleValue();
    //Alternatively, you can take the inner product with the transposed matrix as follows.
    // return x.mmul(x.transpose()).getDouble(0);

// Gradient of function_2 (x0^2 + x1^2) at (3, 4), (0, 2) and (3, 0), compared as formatted strings.
assertEquals("[6.00,8.00]", Util.string(Functions.numerical_gradient(this::function_2, Nd4j.create(new double[] {3.0, 4.0}))));
assertEquals("[0.00,4.00]", Util.string(Functions.numerical_gradient(this::function_2, Nd4j.create(new double[] {0.0, 2.0}))));
assertEquals("[6.00,0.00]", Util.string(Functions.numerical_gradient(this::function_2, Nd4j.create(new double[] {3.0, 0.0}))));

4.4.1 Gradient method

public static INDArray gradient_descent(INDArrayFunction f, INDArray init_x, double lr, int step_num) {
    INDArray x = init_x;
    for (int i = 0; i < step_num; ++i) {
        INDArray grad = Functions.numerical_gradient(f, x);
        INDArray y = x.sub(grad.mul(lr));
//            System.out.printf("step:%d x=%s grad=%s x'=%s%n", i, x, grad, y);
        x = y;
    return x;

// lr = 0.1: starting from (-3, 4), 100 steps are expected to end very close to the origin.
INDArray init_x = Nd4j.create(new double[] {-3.0, 4.0});
INDArray r = gradient_descent(this::function_2, init_x, 0.1, 100);
assertEquals("[-0.00,0.00]", Util.string(r));
assertEquals(-6.11110793e-10, r.getDouble(0), 5e-6);
assertEquals(8.14814391e-10, r.getDouble(1), 5e-6);
// Example of a learning rate that is too large: lr = 10.0
r = gradient_descent(this::function_2, init_x, 10.0, 100);
// The value differs from the Python result, but either way the correct result is not obtained.
assertEquals("[-763,389.44,1,017,852.62]", Util.string(r));
// Example of a learning rate that is too small: lr = 1e-10 (the expected result is still the start point)
r = gradient_descent(this::function_2, init_x, 1e-10, 100);
assertEquals("[-3.00,4.00]", Util.string(r));

4.4.2 Gradient with respect to neural network

static class simpleNet {

    public final INDArray W;

     *Weight 0.0 to 1.Initialize with a random number in the range of 0.
    public simpleNet() {
        try (Random r = new DefaultRandom()) {
            //Create a matrix of random numbers based on a 2x3 Gaussian distribution.
            W = r.nextGaussian(new int[] {2, 3});
        } catch (Exception e) {
            throw new RuntimeException(e);

     *Weights to make sure the results match this book
     *Allows it to be given from the outside.
    public simpleNet(INDArray W) {
        this.W = W.dup();   //Copy defensively.

    public INDArray predict(INDArray x) {
        return x.mmul(W);

    public double loss(INDArray x, INDArray t) {
        INDArray z = predict(x);
        INDArray y = Functions.softmax(z);
        double loss = Functions.cross_entropy_error(y, t);
        return loss;

// For the weights, use the same values as in the book instead of random numbers.
INDArray W = Nd4j.create(new double[][] {
    {0.47355232, 0.9977393, 0.84668094},
    {0.85557411, 0.03563661, 0.69422093},
});
simpleNet net = new simpleNet(W);
assertEquals("[[0.47,1.00,0.85],[0.86,0.04,0.69]]", Util.string(net.W));
INDArray x = Nd4j.create(new double[] {0.6, 0.9});
INDArray p = net.predict(x);
assertEquals("[1.05,0.63,1.13]", Util.string(p));
// The largest predicted value is at index 2.
assertEquals(2, Functions.argmax(p).getInt(0));
INDArray t = Nd4j.create(new double[] {0, 0, 1});
assertEquals(0.92806853663411326, net.loss(x, t), 5e-6);
// The function is defined with a lambda expression. Its argument is ignored and the loss is
// evaluated on net directly; presumably numerical_gradient perturbs net.W in place (TODO confirm).
INDArrayFunction f = dummy -> net.loss(x, t);
INDArray dW = Functions.numerical_gradient(f, net.W);
assertEquals("[[0.22,0.14,-0.36],[0.33,0.22,-0.54]]", Util.string(dW));

4.5 Implementation of learning algorithm

4.5.1 Two-layer neural network class

The two-layer neural network is implemented by the TwoLayerNet class. Weights and biases are stored in TwoLayerParams instead of a Map. Random numbers are generated through ND4J's Random interface. It takes about 5 to 10 minutes in my environment.

// Build a 784-100-10 network and check the shape of every parameter.
TwoLayerNet net = new TwoLayerNet(784, 100, 10);
assertArrayEquals(new int[] {784, 100}, net.parms.get("W1").shape());
assertArrayEquals(new int[] {1, 100}, net.parms.get("b1").shape());
assertArrayEquals(new int[] {100, 10}, net.parms.get("W2").shape());
assertArrayEquals(new int[] {1, 10}, net.parms.get("b2").shape());
try (Random r = new DefaultRandom()) {
    // Random input and target batches of 100 rows are enough for a shape check.
    INDArray x = r.nextGaussian(new int[] {100, 784});
    INDArray t = r.nextGaussian(new int[] {100, 10});
    INDArray y = net.predict(x);
    assertArrayEquals(new int[] {100, 10}, y.shape());
    // Each gradient is expected to have the same shape as the parameter it belongs to.
    Params grads = net.numerical_gradient(x, t);
    assertArrayEquals(new int[] {784, 100}, grads.get("W1").shape());
    assertArrayEquals(new int[] {1, 100}, grads.get("b1").shape());
    assertArrayEquals(new int[] {100, 10}, grads.get("W2").shape());
    assertArrayEquals(new int[] {1, 10}, grads.get("b2").shape());
}

4.5.2 Implementation of mini-batch learning

Using MNIST data is very time consuming. In my environment, it takes about 90 seconds for each loop, so if you loop 10,000 times, it will take about 10 days.

// Load the MNIST dataset.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
assertArrayEquals(new int[] {60000, 784}, x_train.shape());
assertArrayEquals(new int[] {60000, 10}, t_train.shape());
List<Double> train_loss_list = new ArrayList<>();
int iters_num = 10000;
// int train_size = images.size(0);
int batch_size = 100;
double learning_rate = 0.1;
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
// Wrap the training data once; every iteration draws batch_size random samples from it.
DataSet ds = new DataSet(x_train, t_train);
for (int i = 0; i < iters_num; ++i) {
    long start = System.currentTimeMillis();
    // Get a mini-batch.
    DataSet sample = ds.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    Params grad = network.numerical_gradient(x_batch, t_batch);
    // Update every parameter in place: p -= learning_rate * gradient.
    network.parms.update((p, a) -> p.subi(a.mul(learning_rate)), grad);
    // Record the learning progress.
    double loss = network.loss(x_batch, t_batch);
    train_loss_list.add(loss);
    System.out.printf("iteration %d loss=%f elapse=%dms%n",
        i, loss, System.currentTimeMillis() - start);
}

4.5.3 Evaluated with test data

Using MNIST data is very time consuming. In my environment, it takes about 90 seconds for each loop, so if you loop 10,000 times, it will take about 10 days. Therefore, I have not executed it to the end, but it took 4.6 hours to train until the recognition accuracy of both the training data and the test data became 90% or more, and the number of loops was 214. As you can see from the graph in the book, it rises fairly quickly up to about 80%, so it may be better to stop when the recognition accuracy exceeds the threshold value instead of looping 10,000 times.

// Load the MNIST dataset.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertArrayEquals(new int[] {60000, 784}, x_train.shape());
assertArrayEquals(new int[] {60000, 10}, t_train.shape());
List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();
int iters_num = 10000;
int train_size = x_train.size(0);
int batch_size = 100;
double learning_rate = 0.01;
// Number of iterations that make up one epoch (at least 1).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
TwoLayerNet network = new TwoLayerNet(784, 50, 10);
// Wrap the training data once; every iteration draws batch_size random samples from it.
DataSet ds = new DataSet(x_train, t_train);
for (int i = 0; i < iters_num; ++i) {
    long start = System.currentTimeMillis();
    // Get a mini-batch.
    DataSet sample = ds.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    Params grad = network.numerical_gradient(x_batch, t_batch);
    // Update every parameter in place: p -= learning_rate * gradient.
    network.parms.update((p, a) -> p.subi(a.mul(learning_rate)), grad);
    // Record the learning progress.
    double loss = network.loss(x_batch, t_batch);
    train_loss_list.add(loss);
    // Compute the recognition accuracy once per epoch.
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.printf("train acc, test acc | %s, %s%n",
            train_acc, test_acc);
    }
    System.out.printf("iteration %d loss=%f elapse=%dms%n",
        i, loss, System.currentTimeMillis() - start);
}

Quick learning Java "Introduction?" Part 3 Talking away from programming