Deep Learning Java from scratch 6.4 Régularisation

table des matières

6.4.1 Surapprentissage

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));
// weight decay (Amortissement de la charge)paramètres de===========
double weight_decay_lambda = 0; //Lorsque la décroissance du poids n'est pas utilisée
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
for (int i = 0; i < 1000000000; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
        if (epoch_cnt >= max_epochs)
            break;
    }
}

// 3.Dessiner un graphique=============
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Précision de reconnaissance dans le surapprentissage", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "overfit.png "));

overfit.png

6.4.2 Weidht decay

MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));
// weight decay (Amortissement de la charge)paramètres de===========
// weight_decay_lambda = 0 //Lorsque vous n'utilisez pas la décroissance du poids
double weight_decay_lambda = 0.1;
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);
int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;

List<Double> train_loss_list = new ArrayList<>();
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();

int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;
for (int i = 0; i < 1000000000; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
        if (epoch_cnt >= max_epochs)
            break;
    }
}

// 3.Dessiner un graphique=============
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Précision de reconnaissance dans le surapprentissage en utilisant la décroissance du poids", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "overfit_weight_decay.png "));

overfit_weight_decay.png

6.4.3 Dropout


MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
//Réduisez les données d'entraînement pour reproduire le surentraînement
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
//Présence / absence de décrochage, réglage du ratio====================
boolean use_dropout = true; //Faux quand il n'y a pas de décrochage
double dropout_ratio = 0.2;
// =============================================
MultiLayerNetExtend network = new MultiLayerNetExtend(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu",
    /*weight_init_std*/"relu",
    /*weight_decay_lambda*/0,
    /*use_dropout*/use_dropout, /*dropout_ratio*/dropout_ratio,
    /*use_bachnorm*/false);
Trainer trainer = new Trainer(network, x_train, t_train, x_test, t_test,
    /*epochs*/301,
    /*mini_batch_size*/100,
    /*optimizer*/() -> new SGD(0.01),
    /*evaluate_sample_num_per_epoch*/0,
    /*verbose*/true);

trainer.train();
List<Double> train_acc_list = trainer.train_acc_list;
List<Double> test_acc_list = trainer.test_acc_list;

// 3.Dessiner un graphique=============
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Précision de reconnaissance dans Dropout", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
graph.writeTo(new File(dir, "dropout.png "));

dropout.png

Recommended Posts

Deep Learning Java from scratch 6.4 Régularisation
Étudiez le Deep Learning à partir de zéro en Java.
Deep Learning Java à partir de zéro Chapitre 1 Introduction
Deep Learning Java from scratch 6.1 Mise à jour des paramètres
Deep Learning Java à partir de zéro Chapitre 2 Perceptron
Deep Learning Java from scratch 6.3 Normalisation par lots
Deep Learning from scratch Java Chapter 4 Apprentissage des réseaux de neurones
[Deep Learning from scratch] dans Java 3. Réseau neuronal
Deep Learning Java à partir de zéro Chapitre 3 Réseau neuronal
Deep Learning Java from scratch 6.2 Valeur initiale du poids
Deep Learning Java from scratch Chapter 5 Méthode de propagation de retour d'erreur
Configuration PC la plus rapide pour un apprentissage en profondeur à partir de zéro
La vie Java à partir de zéro
[Apprentissage profond à partir de zéro] 2. Il n'existe pas de NumPy en Java.
[Deep Learning from scratch] en Java 1. Pour le moment, différenciation et différenciation partielle
Java scratch scratch
Premiers pas pour l'apprentissage profond en Java
Pour l'apprentissage JAVA (2018-03-16-01)
Jour d'apprentissage Java 5
java learning day 2
java learning day 1
Appeler Java depuis JRuby
Créer un environnement VS Code + WSL + Java + Gradle à partir de zéro
Changements de Java 8 à Java 11
Somme de Java_1 à 100
Java Learning 2 (Apprenez la méthode de calcul)
apprentissage java (expression conditionnelle)
Mémo d'apprentissage Java (méthode)
Évaluer la source Java à partir de Java
Apprendre Java (1) - Hello World
Interface d'historique d'apprentissage JAVA
Accédez à API.AI depuis Java
Mémo d'apprentissage Java (basique)
De Java à Ruby !!
[Note] Créez un environnement Java à partir de zéro avec docker
Mémo d'apprentissage Java (interface)
Mémo d'apprentissage Java (héritage)
Apprentissage rapide de Java "Introduction?" Partie 3 Parler de programmation
Enfant orienté objet!? J'ai essayé le Deep Learning avec Java (édition d'essai)
Apprentissage du framework Java # 1 (version Mac)
Contenu d'apprentissage de base Java 7 (exception)
Migration de Cobol vers JAVA
Héritage de l'interface de l'historique d'apprentissage JAVA
Mémo d'apprentissage Java (type de données)
Java à partir du débutant, remplacer
Création d'index Elastic Search à partir de Java
Contenu d'apprentissage de base Java 5 (qualificatif)
Nouvelles fonctionnalités de Java7 à Java8
Livres utilisés pour apprendre Java
java learning day 4
Connectez-vous de Java à PostgreSQL
Mémo d'apprentissage Java (opérateur logique)
Java, instance à partir du débutant
Mémo d'apprentissage Java (classe abstraite)
Java à partir de débutant, héritage
Voyons la Deep Java Library (DJL), une bibliothèque capable de gérer Deep Learning en Java, publiée par AWS.
Utilisation de Docker depuis Java Gradle
De Java inefficace à Java efficace
JavaScript vu de Java
Contenu d'apprentissage de base Java 8 (API Java)