Departing a little from the book, I tried creating the histograms while changing the standard deviation of the initial weight values in four ways.
```java
// Java version of ch06/weight_init_activation_histogram.py
int node_num = 100;        // number of nodes (neurons) in each hidden layer
int hidden_layer_size = 5; // 5 hidden layers
Map<Integer, INDArray> activations = new HashMap<>(); // store the activation results here
Map<String, Supplier<INDArray>> ws = new LinkedHashMap<>();
// Experiment by changing the initial value!
ws.put("1.0", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(1.0));
ws.put("0.01", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(0.01));
ws.put("sqrt(1 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(1.0 / node_num)));
ws.put("sqrt(2 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(2.0 / node_num)));
for (String key : ws.keySet()) {
    // reset the input for each experiment (1000 random data points)
    INDArray x = Nd4j.randn(new int[] {1000, 100});
    for (int i = 0; i < hidden_layer_size; ++i) {
        if (i != 0)
            x = activations.get(i - 1);
        INDArray w = ws.get(key).get();
        INDArray a = x.mmul(w);
        // Experiment by changing the type of activation function!
        INDArray z = Functions.sigmoid(a);
        // INDArray z = Functions.relu(a);
        // INDArray z = Functions.tanh(a);
        activations.put(i, z);
    }
    // draw a histogram of the activations in each layer
    for (Entry<Integer, INDArray> e : activations.entrySet()) {
        HistogramImage h = new HistogramImage(320, 240, -0.1, -1000, 1, 40000, 50, e.getValue());
        h.writeTo(new File(Constants.WeightImages, key + "-" + (e.getKey() + 1) + "-layer.png"));
    }
}
```
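`Functions` here is a small helper class from this article's code. A minimal sketch of what it could look like on top of ND4J's built-in `Transforms` follows; the delegation to `Transforms` is my assumption, only the method names are taken from the snippet above:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;

// Hypothetical implementation of the Functions helper used above,
// delegating to ND4J's built-in element-wise transforms.
public class Functions {
    public static INDArray sigmoid(INDArray a) {
        return Transforms.sigmoid(a); // 1 / (1 + exp(-a)), element-wise
    }

    public static INDArray relu(INDArray a) {
        return Transforms.relu(a);    // max(0, a), element-wise
    }

    public static INDArray tanh(INDArray a) {
        return Transforms.tanh(a);    // hyperbolic tangent, element-wise
    }
}
```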
The results of running with the sigmoid activation function at different standard deviations are as follows.
Standard deviation = 0.01

[histograms of the activations: 1-layer through 5-layer]

Standard deviation = 1.0

[histograms of the activations: 1-layer through 5-layer]

Standard deviation = $\sqrt{\frac{1}{n}}$

[histograms of the activations: 1-layer through 5-layer]
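In the code above, $n$ is the number of nodes in the previous layer (`node_num` = 100), and the two scaled initializers correspond to the standard recommendations discussed in the book:

$$
\sigma_{\text{Xavier}} = \sqrt{\frac{1}{n}} \quad \text{(for sigmoid / tanh)}, \qquad
\sigma_{\text{He}} = \sqrt{\frac{2}{n}} \quad \text{(for ReLU)}
$$

As in the book's version of this experiment, with standard deviation 1.0 the sigmoid outputs pile up near 0 and 1 (which leads to vanishing gradients), with 0.01 they concentrate around 0.5 (little representational power), while the Xavier value keeps the activations reasonably spread across all five layers.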
```java
// Java version of ch06/weight_init_compare.py
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
DataSet dataset = new DataSet(x_train, t_train);
int train_size = x_train.size(0);
int batch_size = 128;
int max_iteration = 2000;

// 1: experiment settings
Map<String, String> weight_init_types = new HashMap<>();
weight_init_types.put("std", "0.01");
weight_init_types.put("Xavier", "sigmoid");
weight_init_types.put("He", "relu");
Optimizer optimizer = new SGD(0.01);
Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (Entry<String, String> e : weight_init_types.entrySet()) {
    String key = e.getKey();
    String weight_init_std = e.getValue();
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10, weight_init_std));
    train_loss.put(key, new ArrayList<>());
}

// 2: start training
for (int i = 0; i < max_iteration; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (String key : weight_init_types.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradient(x_batch, t_batch);
        optimizer.update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : weight_init_types.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3: draw the graph
GraphImage graph = new GraphImage(800, 600, -100, -0.2, max_iteration, 2.5);
Map<String, Color> colors = new HashMap<>();
colors.put("std", Color.GREEN);
colors.put("Xavier", Color.RED);
colors.put("He", Color.BLUE);
double h = 1.5;
for (String key : weight_init_types.keySet()) {
    List<Double> losses = train_loss.get(key);
    graph.color(colors.get(key));
    graph.text(key, 1000, h);
    h += 0.1;
    int step = 10;
    graph.plot(0, losses.get(0));
    for (int i = step; i < max_iteration; i += step) {
        graph.line(i - step, losses.get(i - step), i, losses.get(i));
        graph.plot(i, losses.get(i));
    }
}
graph.color(Color.BLACK);
graph.text(String.format("x=(%f,%f),y=(%f,%f)",
    graph.minX, graph.maxX, graph.minY, graph.maxY), 1000, h);
h += 0.1;
graph.text("Comparison by \"initial weight\" for MNIST dataset", 1000, h);
graph.writeTo(Constants.file(Constants.WeightImages, "weight_init_compare.png"));
```
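Note the convention in `weight_init_types`: the value passed as `weight_init_std` is either a literal standard deviation (`"0.01"`) or the name of an activation function, where `"sigmoid"` selects the Xavier initial value and `"relu"` selects the He initial value, as in the book's `MultiLayerNet`. A minimal sketch of how such a string could be resolved to a per-layer scale (the helper name `initScale` is hypothetical):

```java
// Hypothetical helper showing how the weight_init_std string could be
// mapped to a standard deviation for a layer with n input nodes.
static double initScale(String weightInitStd, int n) {
    switch (weightInitStd.toLowerCase()) {
        case "sigmoid":
        case "xavier":
            return Math.sqrt(1.0 / n);   // Xavier initialization
        case "relu":
        case "he":
            return Math.sqrt(2.0 / n);   // He initialization
        default:
            return Double.parseDouble(weightInitStd); // fixed std, e.g. "0.01"
    }
}

// e.g. the weights of a 784 -> 100 layer would then be drawn as
// Nd4j.randn(new int[] {784, 100}).mul(initScale(weight_init_std, 784));
```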
The resulting graph looks like this:

[training-loss curves over 2,000 iterations for std=0.01 (green), Xavier (red), and He (blue)]