Deep Learning Java von Grund auf 6.3 Batch-Normalisierung

Inhaltsverzeichnis

6.3 Bewertung der Chargennormalisierung

Layer Erweitert, damit train_flg an die Schnittstelle übergeben werden kann BatchNormLayer wurde erstellt.

Layer.java


public interface Layer {

    INDArray forward(INDArray x);
    INDArray backward(INDArray x);

}

BatchNormLayer.java


public interface BatchNormLayer extends Layer {

    public default INDArray forward(INDArray x) {
        throw new IllegalAccessError();
    }

    INDArray forward(INDArray x, boolean train_flg);

}

Der Rest ist die Implementierung dieser Schnittstellenklasse Dropout und Batch Normalization Die Klasse wurde implementiert. Implementieren Sie außerdem die Klasse MultiLayerNetExtend und führen Sie den folgenden Testcode aus. Machen.

INDArray x_train;
INDArray t_train;
int max_epochs = 20;
int train_size;
int batch_size = 100;
double learning_rate = 0.01;
DataSet trainDataSet;

List<List<Double>> __train(String weight_init_std) {
    MultiLayerNetExtend bn_network = new MultiLayerNetExtend(
        784, new int[] {100, 100, 100, 100, 100}, 10,
        /*activation=*/"relu",
        /*weight_init_std=*/ weight_init_std,
        /*weight_decay_lambda=*/ 0,
        /*use_dropout=*/ false,
        /*dropout_ration=*/ 0.5,
        /*use_batchNorm=*/ true);
    MultiLayerNetExtend network = new MultiLayerNetExtend(
        784, new int[] {100, 100, 100, 100, 100}, 10,
        /*activation=*/"relu",
        /*weight_init_std=*/ weight_init_std,
        /*weight_decay_lambda=*/ 0,
        /*use_dropout=*/ false,
        /*dropout_ration=*/ 0.5,
        /*use_batchNorm=*/ false);
    List<MultiLayerNetExtend> networks = Arrays.asList(bn_network, network);
    Optimizer optimizer = new SGD(learning_rate);
    List<Double> train_acc_list = new ArrayList<>();
    List<Double> bn_train_acc_lsit = new ArrayList<>();
    int iter_per_epoch = Math.max(train_size / batch_size, 1);
    int epoch_cnt = 0;
    for (int i = 0; i < 1000000000; ++i) {
        DataSet sample = trainDataSet.sample(batch_size);
        INDArray x_batch = sample.getFeatureMatrix();
        INDArray t_batch = sample.getLabels();
        for (MultiLayerNetExtend _network : networks) {
            Params grads = _network.gradient(x_batch, t_batch);
            optimizer.update(_network.params, grads);
        }
        if (i % iter_per_epoch == 0) {
            double train_acc = network.accuracy(x_train, t_train);
            double bn_train_acc = bn_network.accuracy(x_train, t_train);
            train_acc_list.add(train_acc);
            bn_train_acc_lsit.add(bn_train_acc);
            System.out.println("epoch:" + epoch_cnt + " | " + train_acc + " - " + bn_train_acc);
            ++epoch_cnt;
            if (epoch_cnt >= max_epochs)
                break;
        }
    }
    return Arrays.asList(train_acc_list, bn_train_acc_lsit);
}

@Test
public void C6_3_2_Batch_Bewertung der Normalisierung() throws IOException {
    // ch06/batch_norm_test.Java-Version von py.
    MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
    x_train = train.normalizedImages();
    t_train = train.oneHotLabels();
    trainDataSet = new DataSet(x_train, t_train);
    train_size = x_train.size(0);

    //Zeichnen eines Diagramms
    File dir = Constants.WeightImages;
    if (!dir.exists()) dir.mkdirs();
    String[] names = {"BatchNormalization", "Normal"};
    Color[] colors = {Color.BLUE, Color.RED};
    INDArray weight_scale_list = Functions.logspace(0, -4, 16);
    INDArray x = Functions.arrange(max_epochs);
    for (int i = 0; i < weight_scale_list.length(); ++i) {
        System.out.println( "============== " + (i+1) + "/16" + " ==============");
        double w = weight_scale_list.getDouble(i);
        List<List<Double>> acc_list = __train(String.valueOf(w));
        GraphImage graph = new GraphImage(640, 480, -1, -0.1, 20, 1.0);
        for (int j = 0; j < names.length; ++j) {
            graph.color(colors[j]);
            graph.textInt(names[j] + " : " + w, 20, 20 * j + 20);
            graph.plot(0, acc_list.get(j).get(0));
            for (int k = 1; k < acc_list.get(j).size(); ++k) {
                graph.line(k - 1, acc_list.get(j).get(k - 1), k, acc_list.get(j).get(k));
                graph.plot(k, acc_list.get(j).get(k));
            }
        }
        File file = new File(dir, "BatchNormalization#" + w + ".png ");
        graph.writeTo(file);
    }
}

Das resultierende Diagramm sieht folgendermaßen aus:

W=1.0 BatchNormalization#1.0.png

W=0.29286444187164307 BatchNormalization#0.29286444187164307.png

W=0.00009999999747378752 BatchNormalization#9.999999747378752E-5.png

Im Gegensatz zu diesem Buch ist das Lernen umso schneller, je kleiner W (Standardabweichung des Anfangsgewichts) ist.

Recommended Posts

Deep Learning Java von Grund auf 6.3 Batch-Normalisierung
Deep Learning Java von Grund auf 6.4 Regularisierung
Lernen Sie Deep Learning von Grund auf in Java.
Deep Learning Java von Grund auf neu Kapitel 1 Einführung
Deep Learning Java von Grund auf 6.1 Parameteraktualisierung
Deep Learning Java von Grund auf neu Kapitel 2 Perceptron
Deep Learning Java von Grund auf neu Kapitel 3 Neuronales Netzwerk
Deep Learning Java von Grund auf 6.2 Anfangswert des Gewichts
Deep Learning Java von Grund auf neu Kapitel 5 Methode zur Fehlerrückübertragung
Schnellstes PC-Setup für tiefes Lernen von Grund auf
[Deep Learning von Grund auf neu] 2. In Java gibt es kein NumPy.
Java-Leben von vorne anfangen
[Deep Learning von Grund auf neu] in Java 1. Zur Zeit Differenzierung und teilweise Differenzierung
Führen Sie eine Batchdatei von Java aus
Java lernen (0)
Java Scratch Scratch
Erste Schritte für tiefes Lernen in Java
Java-Lerntag 5
Java-Lerntag 2
Java-Lerntag 1
Erstellen Sie die VS Code + WSL + Java + Gradle-Umgebung von Grund auf neu
[Hinweis] Erstellen Sie mit Docker eine Java-Umgebung von Grund auf neu
Schnell lernen Java "Einführung?" Teil 3 Von der Programmierung wegreden
Rufen Sie Java von JRuby aus auf
Änderungen von Java 8 zu Java 11
Summe von Java_1 bis 100
Java-Lernen (bedingter Ausdruck)
Java-Lernnotiz (Methode)
Eval Java-Quelle von Java
Greifen Sie über Java auf API.AI zu
Von Java zu Ruby !!
Java-Lernnotiz (Schnittstelle)
Java-Lernnotiz (Vererbung)
Objektorientiertes Kind !? Ich habe Deep Learning mit Java ausprobiert (Testversion)
Lassen Sie uns auf Deep Java Library (DJL) eingehen, eine von AWS veröffentlichte Bibliothek, die Deep Learning in Java verarbeiten kann.