Deep Learning Java von Grund auf 6.2 Anfangswert des Gewichts

Inhaltsverzeichnis

6.2.2 Aktivierungsverteilung für versteckte Schichten

Etwas anders als in diesem Buch habe ich versucht, ein Diagramm zu erstellen, während ich die Standardabweichung des Anfangswertes auf vier Arten geändert habe.

// ch06/weight_init_activtion_histogram.Java-Version von py.
INDArray x = Nd4j.randn(new int[] {1000, 100}); //1000 Daten
int node_num = 100; //Anzahl der Knoten (Neuronen) in jeder verborgenen Schicht
int hidden_layer_size = 5; //5 versteckte Schichten
Map<Integer, INDArray> activations = new HashMap<>(); //Speichern Sie das Aktivierungsergebnis hier
Map<String, Supplier<INDArray>> ws = new LinkedHashMap<>();
//Experimentieren wir, indem wir den Anfangswert ändern!
ws.put("1.0", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(1));
ws.put("0.01", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(0.01));
ws.put("sqrt(1 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(1.0 / node_num)));
ws.put("sqrt(2 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(2.0 / node_num)));
for (String key : ws.keySet()) {
    for (int i = 0; i < hidden_layer_size; ++i) {
        if (i != 0)
            x = activations.get(i - 1);
        INDArray w = ws.get(key).get();
        INDArray a = x.mmul(w);
        //Experimentieren wir, indem wir die Art der Aktivierungsfunktion ändern!
        INDArray z = Functions.sigmoid(a);
        // INDArray z = Functions.relu(a);
        // INDArray z = Functions.tanh(a);
        activations.put(i, z);
    }
    //Zeichnen Sie ein Histogramm
    for (Entry<Integer, INDArray> e : activations.entrySet()) {
        HistogramImage h = new HistogramImage(320, 240, -0.1, -1000, 1, 40000, 50, e.getValue());
        h.writeTo(new File(Constants.WeightImages, key + "-" + (e.getKey() + 1) + "-layer.png "));
    }
}

Das Ergebnis der Verwendung der Aktivierungsfunktion Sigmoid mit unterschiedlichen Standardabweichungen ist wie folgt.

Standardabweichung = 0,01

1-layer 2-layer 3-layer 4-layer 5-layer
0.01-1-layer.png 0.01-2-layer.png 0.01-3-layer.png 0.01-4-layer.png 0.01-5-layer.png

Standardabweichung = 1,0

1-layer 2-layer 3-layer 4-layer 5-layer
1.0-1-layer.png 1.0-2-layer.png 1.0-3-layer.png 1.0-4-layer.png 1.0-5-layer.png

Standardabweichung = $ \ sqrt {\ frac {1} {n}} $

1-layer 2-layer 3-layer 4-layer 5-layer
sqrt(1 div n)-1-layer.png sqrt(1divn)-2-layer.png sqrt(1divn)-3-layer.png sqrt(1divn)-4-layer.png sqrt(1divn)-5-layer.png

6.2.4 Vergleich der anfänglichen Gewichtswerte nach MNIST-Datensatz

// ch06/weight_init_compare.Java-Version von py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
DataSet dataset = new DataSet(x_train, t_train);
int train_size = x_train.size(0);
int batch_size = 128;
int max_iteration = 2000;

// 1:Experimentelle Einstellungen
Map<String, String> weight_init_types = new HashMap<>();
weight_init_types.put("std", "0.01");
weight_init_types.put("Xavier", "sigmoid");
weight_init_types.put("He", "relu");
Optimizer optimizer = new SGD(0.01);

Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (Entry<String, String> e : weight_init_types.entrySet()) {
    String key = e.getKey();
    String weight_init_std = e.getValue();
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10, weight_init_std));
    train_loss.put(key, new ArrayList<>());
}

//2:Beginn des Trainings
for (int i = 0; i < max_iteration; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    for (String key : weight_init_types.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizer.update(network.params, grads);

        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }

    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : weight_init_types.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3:Zeichnen eines Diagramms
GraphImage graph = new GraphImage(800, 600, -100, -0.2, max_iteration, 2.5);
Map<String, Color> colors = new HashMap<>();
colors.put("std", Color.GREEN);
colors.put("Xavier", Color.RED);
colors.put("He", Color.BLUE);
double h = 1.5;
for (String key : weight_init_types.keySet()) {
    List<Double> losses = train_loss.get(key);
    graph.color(colors.get(key));
    graph.text(key, 1000, h);
    h += 0.1;
    int step = 10;
    graph.plot(0, losses.get(0));
    for (int i = step; i < max_iteration; i += step) {
        graph.line(i - step, losses.get(i - step), i, losses.get(i));
        graph.plot(i, losses.get(i));
    }
}
graph.color(Color.BLACK);
graph.text(String.format("x=(%f,%f),y=(%f,%f)",
    graph.minX, graph.maxX, graph.minY, graph.maxY), 1000, h);
h += 0.1;
graph.text("Vergleich nach "Anfangsgewicht" für MNIST-Datensatz", 1000, h);
graph.writeTo(Constants.file(Constants.WeightImages, "weight_init_compare.png "));
}

Das resultierende Diagramm sieht folgendermaßen aus:

weight_init_compare.png

Recommended Posts

Deep Learning Java von Grund auf 6.2 Anfangswert des Gewichts
Deep Learning Java von Grund auf 6.4 Regularisierung
Lernen Sie Deep Learning von Grund auf in Java.
Deep Learning Java von Grund auf neu Kapitel 1 Einführung
Deep Learning Java von Grund auf 6.1 Parameteraktualisierung
Deep Learning Java von Grund auf neu Kapitel 2 Perceptron
Deep Learning Java von Grund auf neu Kapitel 3 Neuronales Netzwerk
Deep Learning Java von Grund auf neu Kapitel 5 Methode zur Fehlerrückübertragung
Schnellstes PC-Setup für tiefes Lernen von Grund auf
[Deep Learning von Grund auf neu] 2. In Java gibt es kein NumPy.
[Deep Learning von Grund auf neu] in Java 1. Zur Zeit Differenzierung und teilweise Differenzierung
Java-Lerntag 4
Java-Leben von vorne anfangen
Bedeutung der aus der Java Collection gelernten Schnittstelle
Erste Schritte für tiefes Lernen in Java
Java: So senden Sie Werte von Servlet zu Servlet
Java lernen (0)
Java Scratch Scratch
Die Geschichte des Lernens von Java in der ersten Programmierung
[Java] Ruft mehrere Werte von einem Rückgabewert ab