Deep Learning Java von Grund auf 6.4 Regularisierung

Inhaltsverzeichnis

6.4.1 Überlernen

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));
// weight decay (Lastdämpfung)Einstellungen von===========
double weight_decay_lambda = 0; //Wenn Sie keine Gewichtsabnahme verwenden
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
for (int i = 0; i < 1000000000; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
        if (epoch_cnt >= max_epochs)
            break;
    }
}

// 3.Zeichnen eines Diagramms=============
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Erkennungsgenauigkeit beim Überlernen", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "overfit.png "));

overfit.png

6.4.2 Weidht decay

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));
// weight decay (Lastdämpfung)Einstellungen von===========
// weight_decay_lambda = 0 //Wenn Sie keine Gewichtsabnahme verwenden
double weight_decay_lambda = 0.1;
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
for (int i = 0; i < 1000000000; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
        if (epoch_cnt >= max_epochs)
            break;
    }
}

// 3.Zeichnen eines Diagramms=============
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Erkennungsgenauigkeit beim Überlernen mithilfe des Gewichtsabfalls", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "overfit_weight_decay.png "));

overfit_weight_decay.png

6.4.3 Dropout


MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
//Reduzieren Sie die Trainingsdaten, um Übertraining zu reproduzieren
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
//Vorhandensein / Nichtvorhandensein von Dropout, Einstellung des Verhältnisses====================
boolean use_dropout = true; //Falsch, wenn es keinen Ausfall gibt
double dropout_ratio = 0.2;
// =============================================
MultiLayerNetExtend network = new MultiLayerNetExtend(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu",
    /*weight_init_std*/"relu",
    /*weight_decay_lambda*/0,
    /*use_dropout*/use_dropout, /*dropout_ratio*/dropout_ratio,
    /*use_bachnorm*/false);
Trainer trainer = new Trainer(network, x_train, t_train, x_test, t_test,
    /*epochs*/301,
    /*mini_batch_size*/100,
    /*optimizer*/() -> new SGD(0.01),
    /*evaluate_sample_num_per_epoch*/0,
    /*verbose*/true);

trainer.train();
List<Double> train_acc_list = trainer.train_acc_list;
List<Double> test_acc_list = trainer.test_acc_list;

// 3.Zeichnen eines Diagramms=============
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Erkennungsgenauigkeit in Dropout", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "dropout.png "));

dropout.png

Recommended Posts

Deep Learning Java von Grund auf 6.4 Regularisierung
Lernen Sie Deep Learning von Grund auf in Java.
Deep Learning Java von Grund auf neu Kapitel 1 Einführung
Deep Learning Java von Grund auf 6.1 Parameteraktualisierung
Deep Learning Java von Grund auf neu Kapitel 2 Perceptron
Deep Learning Java von Grund auf 6.3 Batch-Normalisierung
Deep Learning von Grund auf neu Java Kapitel 4 Lernen neuronaler Netze
[Deep Learning von Grund auf neu] in Java 3. Neuronales Netzwerk
Deep Learning Java von Grund auf neu Kapitel 3 Neuronales Netzwerk
Deep Learning Java von Grund auf 6.2 Anfangswert des Gewichts
Deep Learning Java von Grund auf neu Kapitel 5 Methode zur Fehlerrückübertragung
Schnellstes PC-Setup für tiefes Lernen von Grund auf
Java-Leben von vorne anfangen
[Deep Learning von Grund auf neu] 2. In Java gibt es kein NumPy.
[Deep Learning von Grund auf neu] in Java 1. Zur Zeit Differenzierung und teilweise Differenzierung
Java Scratch Scratch
Erste Schritte für tiefes Lernen in Java
Für JAVA-Lernen (2018-03-16-01)
Java-Lerntag 5
Java-Lerntag 2
Java-Lerntag 1
Rufen Sie Java von JRuby aus auf
Erstellen Sie die VS Code + WSL + Java + Gradle-Umgebung von Grund auf neu
Änderungen von Java 8 zu Java 11
Summe von Java_1 bis 100
Java Learning 2 (Lernen Sie die Berechnungsmethode)
Java-Lernen (bedingter Ausdruck)
Java-Lernnotiz (Methode)
Eval Java-Quelle von Java
Java lernen (1) -Hallo Welt
Greifen Sie über Java auf API.AI zu
Java-Lernnotiz (grundlegend)
Von Java zu Ruby !!
[Hinweis] Erstellen Sie mit Docker eine Java-Umgebung von Grund auf neu
Java-Lernnotiz (Schnittstelle)
Java-Lernnotiz (Vererbung)
Schnell lernen Java "Einführung?" Teil 3 Von der Programmierung wegreden
Objektorientiertes Kind !? Ich habe Deep Learning mit Java ausprobiert (Testversion)
Lernen von Java Framework # 1 (Mac-Version)
Java Basic Learning Content 7 (Ausnahme)
Migration von Cobol nach JAVA
Vererbung der JAVA-Lernverlaufsschnittstelle
Java-Lernnotiz (Datentyp)
Java ab Anfänger überschreiben
Elastic Search Indexerstellung aus Java
Java Basic Learning Content 5 (Qualifikation)
Neue Funktionen von Java7 bis Java8
Bücher zum Erlernen von Java
Java-Lerntag 4
Stellen Sie eine Verbindung von Java zu PostgreSQL her
Java-Lernnotiz (logischer Operator)
Java, Instanz für Anfänger
Java-Lernnotiz (abstrakte Klasse)
Java ab Anfänger, Vererbung
Lassen Sie uns auf Deep Java Library (DJL) eingehen, eine von AWS veröffentlichte Bibliothek, die Deep Learning in Java verarbeiten kann.
Verwenden von Docker von Java Gradle
Von ineffektivem Java zu effektivem Java
JavaScript von Java aus gesehen
Java Basic Learning Content 8 (Java-API)