Deep Learning Java from scratch 6.3 Normalisation par lots

table des matières

6.3 Évaluation de la normalisation des lots

Layer Étendu pour permettre à train_flg d'être transmis à l'interface BatchNormLayer a été créé.

Layer.java


public interface Layer {

    INDArray forward(INDArray x);
    INDArray backward(INDArray x);

}

BatchNormLayer.java


public interface BatchNormLayer extends Layer {

    public default INDArray forward(INDArray x) {
        throw new IllegalAccessError();
    }

    INDArray forward(INDArray x, boolean train_flg);

}

Le reste est l'implémentation de cette classe d'interface Dropout et Batch Normalization Implémentation de la classe. En outre, implémentez la classe MultiLayerNetExtend et exécutez le code de test suivant. Faire.

INDArray x_train;
INDArray t_train;
int max_epochs = 20;
int train_size;
int batch_size = 100;
double learning_rate = 0.01;
DataSet trainDataSet;

List<List<Double>> __train(String weight_init_std) {
    MultiLayerNetExtend bn_network = new MultiLayerNetExtend(
        784, new int[] {100, 100, 100, 100, 100}, 10,
        /*activation=*/"relu",
        /*weight_init_std=*/ weight_init_std,
        /*weight_decay_lambda=*/ 0,
        /*use_dropout=*/ false,
        /*dropout_ration=*/ 0.5,
        /*use_batchNorm=*/ true);
    MultiLayerNetExtend network = new MultiLayerNetExtend(
        784, new int[] {100, 100, 100, 100, 100}, 10,
        /*activation=*/"relu",
        /*weight_init_std=*/ weight_init_std,
        /*weight_decay_lambda=*/ 0,
        /*use_dropout=*/ false,
        /*dropout_ration=*/ 0.5,
        /*use_batchNorm=*/ false);
    List<MultiLayerNetExtend> networks = Arrays.asList(bn_network, network);
    Optimizer optimizer = new SGD(learning_rate);
    List<Double> train_acc_list = new ArrayList<>();
    List<Double> bn_train_acc_lsit = new ArrayList<>();
    int iter_per_epoch = Math.max(train_size / batch_size, 1);
    int epoch_cnt = 0;
    for (int i = 0; i < 1000000000; ++i) {
        DataSet sample = trainDataSet.sample(batch_size);
        INDArray x_batch = sample.getFeatureMatrix();
        INDArray t_batch = sample.getLabels();
        for (MultiLayerNetExtend _network : networks) {
            Params grads = _network.gradient(x_batch, t_batch);
            optimizer.update(_network.params, grads);
        }
        if (i % iter_per_epoch == 0) {
            double train_acc = network.accuracy(x_train, t_train);
            double bn_train_acc = bn_network.accuracy(x_train, t_train);
            train_acc_list.add(train_acc);
            bn_train_acc_lsit.add(bn_train_acc);
            System.out.println("epoch:" + epoch_cnt + " | " + train_acc + " - " + bn_train_acc);
            ++epoch_cnt;
            if (epoch_cnt >= max_epochs)
                break;
        }
    }
    return Arrays.asList(train_acc_list, bn_train_acc_lsit);
}

@Test
public void C6_3_2_Batch_Évaluation de la normalisation() throws IOException {
    // ch06/batch_norm_test.La version Java de py.
    MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
    x_train = train.normalizedImages();
    t_train = train.oneHotLabels();
    trainDataSet = new DataSet(x_train, t_train);
    train_size = x_train.size(0);

    //Dessiner un graphique
    File dir = Constants.WeightImages;
    if (!dir.exists()) dir.mkdirs();
    String[] names = {"BatchNormalization", "Normal"};
    Color[] colors = {Color.BLUE, Color.RED};
    INDArray weight_scale_list = Functions.logspace(0, -4, 16);
    INDArray x = Functions.arrange(max_epochs);
    for (int i = 0; i < weight_scale_list.length(); ++i) {
        System.out.println( "============== " + (i+1) + "/16" + " ==============");
        double w = weight_scale_list.getDouble(i);
        List<List<Double>> acc_list = __train(String.valueOf(w));
        GraphImage graph = new GraphImage(640, 480, -1, -0.1, 20, 1.0);
        for (int j = 0; j < names.length; ++j) {
            graph.color(colors[j]);
            graph.textInt(names[j] + " : " + w, 20, 20 * j + 20);
            graph.plot(0, acc_list.get(j).get(0));
            for (int k = 1; k < acc_list.get(j).size(); ++k) {
                graph.line(k - 1, acc_list.get(j).get(k - 1), k, acc_list.get(j).get(k));
                graph.plot(k, acc_list.get(j).get(k));
            }
        }
        File file = new File(dir, "BatchNormalization#" + w + ".png ");
        graph.writeTo(file);
    }
}

Le graphique résultant ressemble à ceci:

W=1.0 BatchNormalization#1.0.png

W=0.29286444187164307 BatchNormalization#0.29286444187164307.png

W=0.00009999999747378752 BatchNormalization#9.999999747378752E-5.png

Contrairement à ce livre, plus le W (écart type du poids initial) est petit, plus l'apprentissage progresse rapidement.

Recommended Posts

Deep Learning Java from scratch 6.3 Normalisation par lots
Deep Learning Java from scratch 6.4 Régularisation
Étudiez le Deep Learning à partir de zéro en Java.
Deep Learning Java à partir de zéro Chapitre 1 Introduction
Deep Learning Java from scratch 6.1 Mise à jour des paramètres
Deep Learning Java à partir de zéro Chapitre 2 Perceptron
Deep Learning Java à partir de zéro Chapitre 3 Réseau neuronal
Deep Learning Java from scratch 6.2 Valeur initiale du poids
Deep Learning Java from scratch Chapter 5 Méthode de propagation de retour d'erreur
Configuration PC la plus rapide pour un apprentissage en profondeur à partir de zéro
[Apprentissage profond à partir de zéro] 2. Il n'existe pas de NumPy en Java.
La vie Java à partir de zéro
[Deep Learning from scratch] en Java 1. Pour le moment, différenciation et différenciation partielle
Exécuter le fichier de commandes à partir de Java
Apprendre Java (0)
Java scratch scratch
Premiers pas pour l'apprentissage profond en Java
Jour d'apprentissage Java 5
java learning day 2
java learning day 1
Créer un environnement VS Code + WSL + Java + Gradle à partir de zéro
[Note] Créez un environnement Java à partir de zéro avec docker
Apprentissage rapide de Java "Introduction?" Partie 3 Parler de programmation
Appeler Java depuis JRuby
Changements de Java 8 à Java 11
Somme de Java_1 à 100
apprentissage java (expression conditionnelle)
Mémo d'apprentissage Java (méthode)
Évaluer la source Java à partir de Java
Interface d'historique d'apprentissage JAVA
Accédez à API.AI depuis Java
De Java à Ruby !!
Mémo d'apprentissage Java (interface)
Mémo d'apprentissage Java (héritage)
Enfant orienté objet!? J'ai essayé le Deep Learning avec Java (édition d'essai)
Voyons la Deep Java Library (DJL), une bibliothèque capable de gérer Deep Learning en Java, publiée par AWS.