Deep Learning Java from scratch 6.2 Initial values of weights


6.2.2 Hidden layer activation distribution

Departing slightly from the book, I generated histograms while varying the standard deviation of the initial weight values in four ways.

// Java version of ch06/weight_init_activation_histogram.py.
INDArray input = Nd4j.randn(new int[] {1000, 100}); // 1000 data points
int node_num = 100; // number of nodes (neurons) in each hidden layer
int hidden_layer_size = 5; // 5 hidden layers
Map<Integer, INDArray> activations = new HashMap<>(); // store the activation results here
Map<String, Supplier<INDArray>> ws = new LinkedHashMap<>();
// Experiment by changing the initial values!
ws.put("1.0", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(1.0));
ws.put("0.01", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(0.01));
ws.put("sqrt(1 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(1.0 / node_num)));
ws.put("sqrt(2 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(2.0 / node_num)));
for (String key : ws.keySet()) {
    INDArray x = input; // restart from the original input for each weight setting
    for (int i = 0; i < hidden_layer_size; ++i) {
        if (i != 0)
            x = activations.get(i - 1);
        INDArray w = ws.get(key).get();
        INDArray a = x.mmul(w);
        // Experiment by changing the activation function!
        INDArray z = Functions.sigmoid(a);
        // INDArray z = Functions.relu(a);
        // INDArray z = Functions.tanh(a);
        activations.put(i, z);
    }
    // Draw a histogram of each layer's activations (HistogramImage is the author's helper class).
    for (Entry<Integer, INDArray> e : activations.entrySet()) {
        HistogramImage h = new HistogramImage(320, 240, -0.1, -1000, 1, 40000, 50, e.getValue());
        h.writeTo(new File(Constants.WeightImages, key + "-" + (e.getKey() + 1) + "-layer.png"));
    }
}
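Functions above is a helper class from the author's repository. If you want a self-contained substitute, ND4J's Transforms class already provides these activations; a minimal stand-in might look like this (the class and method names here are my assumption, not necessarily the author's):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

// Minimal stand-in for the author's Functions helper, delegating to ND4J.
public class Functions {
    public static INDArray sigmoid(INDArray x) { return Transforms.sigmoid(x); } // 1 / (1 + e^-x)
    public static INDArray relu(INDArray x)    { return Transforms.relu(x); }    // max(0, x)
    public static INDArray tanh(INDArray x)    { return Transforms.tanh(x); }    // hyperbolic tangent
}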

The results of running with the sigmoid activation function at each standard deviation are as follows.

Standard deviation = 0.01

[Histograms of the activations in layers 1–5: 0.01-1-layer.png through 0.01-5-layer.png]

Standard deviation = 1.0

[Histograms of the activations in layers 1–5: 1.0-1-layer.png through 1.0-5-layer.png]

Standard deviation = $\sqrt{\frac{1}{n}}$ (Xavier)

[Histograms of the activations in layers 1–5: sqrt(1 div n)-1-layer.png through sqrt(1 div n)-5-layer.png]

As the book explains, a standard deviation of 1.0 drives the activations toward 0 and 1, saturating the sigmoid and causing vanishing gradients; 0.01 crowds them around 0.5, which limits the network's expressiveness; and the Xavier value $\sqrt{\frac{1}{n}}$ keeps them well spread across the layers.
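Why $\sqrt{\frac{1}{n}}$ works: for a node receiving $n$ inputs, if the inputs $x_i$ and weights $w_i$ are independent with zero mean, the pre-activation $a = \sum_{i=1}^{n} w_i x_i$ has variance

$$\mathrm{Var}(a) = \sum_{i=1}^{n} \mathrm{Var}(w_i)\,\mathrm{Var}(x_i) = n\,\mathrm{Var}(w)\,\mathrm{Var}(x),$$

so choosing $\mathrm{Var}(w) = \frac{1}{n}$ (standard deviation $\sqrt{\frac{1}{n}}$) keeps the variance roughly constant from layer to layer. He's $\sqrt{\frac{2}{n}}$ doubles the variance to compensate for ReLU zeroing out the negative half of its inputs.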

6.2.4 Comparison of initial weight values on the MNIST dataset

// Java version of ch06/weight_init_compare.py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
DataSet dataset = new DataSet(x_train, t_train);
int train_size = x_train.size(0);
int batch_size = 128;
int max_iteration = 2000;

// 1: Experiment settings.
// Key = label for the plot; value = the weight_init_std passed to MultiLayerNet
// ("0.01" = fixed standard deviation, "sigmoid" = Xavier, "relu" = He).
Map<String, String> weight_init_types = new HashMap<>();
weight_init_types.put("std", "0.01");
weight_init_types.put("Xavier", "sigmoid");
weight_init_types.put("He", "relu");
Optimizer optimizer = new SGD(0.01);

Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (Entry<String, String> e : weight_init_types.entrySet()) {
    String key = e.getKey();
    String weight_init_std = e.getValue();
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10, weight_init_std));
    train_loss.put(key, new ArrayList<>());
}
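// Note (my sketch, not the author's code): in the book's Python MultiLayerNet,
// the weight_init_std string is interpreted roughly as follows, where fanIn is
// the size of the previous layer:
//
//     case "relu", "he"        -> scale = Math.sqrt(2.0 / fanIn)  // He initial value
//     case "sigmoid", "xavier" -> scale = Math.sqrt(1.0 / fanIn)  // Xavier initial value
//     otherwise                -> scale = Double.parseDouble(weight_init_std)
//
// The Java port presumably mirrors this mapping.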

// 2: Start of training
for (int i = 0; i < max_iteration; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    for (String key : weight_init_types.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradient(x_batch, t_batch);
        optimizer.update(network.params, grads);

        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }

    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : weight_init_types.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3: Draw the graph
GraphImage graph = new GraphImage(800, 600, -100, -0.2, max_iteration, 2.5);
Map<String, Color> colors = new HashMap<>();
colors.put("std", Color.GREEN);
colors.put("Xavier", Color.RED);
colors.put("He", Color.BLUE);
double h = 1.5;
for (String key : weight_init_types.keySet()) {
    List<Double> losses = train_loss.get(key);
    graph.color(colors.get(key));
    graph.text(key, 1000, h);
    h += 0.1;
    int step = 10;
    graph.plot(0, losses.get(0));
    for (int i = step; i < max_iteration; i += step) {
        graph.line(i - step, losses.get(i - step), i, losses.get(i));
        graph.plot(i, losses.get(i));
    }
}
graph.color(Color.BLACK);
graph.text(String.format("x=(%f,%f),y=(%f,%f)",
    graph.minX, graph.maxX, graph.minY, graph.maxY), 1000, h);
h += 0.1;
graph.text("Comparison by \"initial weight\" for the MNIST dataset", 1000, h);
graph.writeTo(Constants.file(Constants.WeightImages, "weight_init_compare.png"));
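The Python script smooths train_loss with the book's smooth_curve helper before plotting, whereas this port plots every 10th raw value. If you want smoother curves, a simple exponential moving average would do; a minimal sketch (the method name and the use of a plain EMA instead of the book's windowed filter are my choices, not the author's):

import java.util.ArrayList;
import java.util.List;

// Exponential moving average as a simple stand-in for the book's smooth_curve.
// alpha in (0, 1]: smaller values smooth more aggressively.
static List<Double> smooth(List<Double> losses, double alpha) {
    List<Double> out = new ArrayList<>(losses.size());
    double ema = losses.get(0);
    for (double v : losses) {
        ema = alpha * v + (1.0 - alpha) * ema;
        out.add(ema);
    }
    return out;
}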

The resulting graph looks like this:

[Graph: weight_init_compare.png — training loss per iteration for std = 0.01, Xavier, and He]

As in the book's version of this experiment, the std = 0.01 network makes almost no progress, while the Xavier and He initializations both bring the loss down steadily.
