Deep Learning Java from scratch 6.2 Valeur initiale du poids

table des matières

6.2.2 Distribution d'activation des couches cachées

Un peu différent de ce livre, j'ai essayé de créer un graphique en changeant l'écart type de la valeur initiale de 4 façons.

// ch06/weight_init_activtion_histogram.La version Java de py.
INDArray x = Nd4j.randn(new int[] {1000, 100}); //1000 données
int node_num = 100; //Nombre de nœuds (neurones) dans chaque couche cachée
int hidden_layer_size = 5; //5 couches cachées
Map<Integer, INDArray> activations = new HashMap<>(); //Enregistrez le résultat de l'activation ici
Map<String, Supplier<INDArray>> ws = new LinkedHashMap<>();
//Expérimentons en changeant la valeur initiale!
ws.put("1.0", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(1));
ws.put("0.01", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(0.01));
ws.put("sqrt(1 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(1.0 / node_num)));
ws.put("sqrt(2 div n)", () -> Nd4j.randn(new int[] {node_num, node_num}).mul(Math.sqrt(2.0 / node_num)));
for (String key : ws.keySet()) {
    for (int i = 0; i < hidden_layer_size; ++i) {
        if (i != 0)
            x = activations.get(i - 1);
        INDArray w = ws.get(key).get();
        INDArray a = x.mmul(w);
        //Expérimentons en changeant le type de fonction d'activation!
        INDArray z = Functions.sigmoid(a);
        // INDArray z = Functions.relu(a);
        // INDArray z = Functions.tanh(a);
        activations.put(i, z);
    }
    //Dessinez un histogramme
    for (Entry<Integer, INDArray> e : activations.entrySet()) {
        HistogramImage h = new HistogramImage(320, 240, -0.1, -1000, 1, 40000, 50, e.getValue());
        h.writeTo(new File(Constants.WeightImages, key + "-" + (e.getKey() + 1) + "-layer.png "));
    }
}

Le résultat de l'utilisation de la fonction d'activation sigmoïde avec différents écarts types est le suivant.

Écart type = 0,01

1-layer 2-layer 3-layer 4-layer 5-layer
0.01-1-layer.png 0.01-2-layer.png 0.01-3-layer.png 0.01-4-layer.png 0.01-5-layer.png

Écart type = 1,0

1-layer 2-layer 3-layer 4-layer 5-layer
1.0-1-layer.png 1.0-2-layer.png 1.0-3-layer.png 1.0-4-layer.png 1.0-5-layer.png

Écart type = $ \ sqrt {\ frac {1} {n}} $

1-layer 2-layer 3-layer 4-layer 5-layer
sqrt(1 div n)-1-layer.png sqrt(1divn)-2-layer.png sqrt(1divn)-3-layer.png sqrt(1divn)-4-layer.png sqrt(1divn)-5-layer.png

6.2.4 Comparaison des valeurs de poids initiales par jeu de données MNIST

// ch06/weight_init_compare.La version Java de py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
DataSet dataset = new DataSet(x_train, t_train);
int train_size = x_train.size(0);
int batch_size = 128;
int max_iteration = 2000;

// 1:Paramètres expérimentaux
Map<String, String> weight_init_types = new HashMap<>();
weight_init_types.put("std", "0.01");
weight_init_types.put("Xavier", "sigmoid");
weight_init_types.put("He", "relu");
Optimizer optimizer = new SGD(0.01);

Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (Entry<String, String> e : weight_init_types.entrySet()) {
    String key = e.getKey();
    String weight_init_std = e.getValue();
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10, weight_init_std));
    train_loss.put(key, new ArrayList<>());
}

//2:Début de la formation
for (int i = 0; i < max_iteration; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();

    for (String key : weight_init_types.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizer.update(network.params, grads);

        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }

    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : weight_init_types.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3:Dessiner un graphique
GraphImage graph = new GraphImage(800, 600, -100, -0.2, max_iteration, 2.5);
Map<String, Color> colors = new HashMap<>();
colors.put("std", Color.GREEN);
colors.put("Xavier", Color.RED);
colors.put("He", Color.BLUE);
double h = 1.5;
for (String key : weight_init_types.keySet()) {
    List<Double> losses = train_loss.get(key);
    graph.color(colors.get(key));
    graph.text(key, 1000, h);
    h += 0.1;
    int step = 10;
    graph.plot(0, losses.get(0));
    for (int i = step; i < max_iteration; i += step) {
        graph.line(i - step, losses.get(i - step), i, losses.get(i));
        graph.plot(i, losses.get(i));
    }
}
graph.color(Color.BLACK);
graph.text(String.format("x=(%f,%f),y=(%f,%f)",
    graph.minX, graph.maxX, graph.minY, graph.maxY), 1000, h);
h += 0.1;
graph.text("Comparaison par "poids initial" pour l'ensemble de données MNIST", 1000, h);
graph.writeTo(Constants.file(Constants.WeightImages, "weight_init_compare.png "));
}

Le graphique résultant ressemble à ceci:

weight_init_compare.png

Recommended Posts

Deep Learning Java from scratch 6.2 Valeur initiale du poids
Deep Learning Java from scratch 6.4 Régularisation
Étudiez le Deep Learning à partir de zéro en Java.
Deep Learning Java à partir de zéro Chapitre 1 Introduction
Deep Learning Java from scratch 6.1 Mise à jour des paramètres
Deep Learning Java à partir de zéro Chapitre 2 Perceptron
Deep Learning Java à partir de zéro Chapitre 3 Réseau neuronal
Deep Learning Java from scratch Chapter 5 Méthode de propagation de retour d'erreur
Configuration PC la plus rapide pour un apprentissage en profondeur à partir de zéro
[Apprentissage profond à partir de zéro] 2. Il n'existe pas de NumPy en Java.
[Deep Learning from scratch] en Java 1. Pour le moment, différenciation et différenciation partielle
java learning day 4
La vie Java à partir de zéro
Importance de l'interface apprise de la collection Java
Premiers pas pour l'apprentissage profond en Java
Java: Comment envoyer des valeurs du servlet au servlet
Apprendre Java (0)
Java scratch scratch
L'histoire de l'apprentissage de Java dans la première programmation
[Java] Obtenir plusieurs valeurs à partir d'une valeur de retour