Deep Learning Java from scratch 6.1 Mise à jour des paramètres

table des matières

6.1.2 SGD

Tout d'abord, définissez l'interface Optimizer.

public interface Optimizer {

    void update(Params params, Params args);

}

La mise en œuvre de SGD est la suivante.

public class SGD implements Optimizer {

    /** learning rate (Coefficient d'apprentissage) */
    final double lr;

    public SGD(double lr) {
        this.lr = lr;
    }

    public SGD() {
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        params.update((p, g) -> p.subi(g.mul(lr)), grads);
    }
}

6.1.4 Momentum

public class Momentum implements Optimizer {

    final double lr, momentum;
    Params v;

    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
        this.v = null;
    }

    public Momentum(double lr) {
        this(lr, 0.9);
    }

    public Momentum() {
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        if (v == null)
            v = Params.zerosLike(params);
        v.update((v, g) -> v.muli(momentum), grads);
        v.update((v, g) -> v.subi(g.mul(lr)), grads);
        params.update((p, v) -> p.addi(v), v);
    }
}

6.1.5 AdaGrad

public class AdaGrad implements Optimizer {

    final double lr;
    Params h;

    public AdaGrad(double lr) {
        this.lr = lr;
    }

    public AdaGrad() {
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        if (h == null)
            h = Params.zerosLike(params);
        h.update((h, g) -> h.addi(g.mul(g)), grads);
        params.update((p, g, h) -> p.subi(g.mul(lr).div(Transforms.sqrt(h).add(1e-7))), grads, h);
    }
}

6.1.6 Adam

public class Adam implements Optimizer {

    final double lr, beta1, beta2;
    int iter;
    Params m, v;

    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.iter = 0;
    }

    public Adam(double lr) {
        this(lr, 0.9, 0.999);
    }

    public Adam() {
        this(0.001);
    }

    @Override
    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        }
        ++iter;
        double lr_t = lr * Math.sqrt(1.0 - Math.pow(beta2, iter)) / (1.0 - Math.pow(beta1, iter));
        m.update((m, g) -> m.addi(g.sub(m).mul(1 - beta1)), grads);
        v.update((v, g) -> v.addi(g.mul(g).sub(v).mul(1 - beta2)), grads);
        params.update((p, m, v) -> p.subi(m.mul(lr_t).div(Transforms.sqrt(v).add(1e-7))), m, v);
    }

}

6.1.7 Quelle méthode de mise à jour faut-il utiliser?

Créez une classe simple GraphImage pour créer un graphique fait.

// ch06/optimizer_compare_naive.Une version java de py.
//Créez un graphique à l'aide de GraphImage.
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) outdir.mkdirs();
// BinaryOperator<INDArray> f = (x, y) ->
// x.mul(x).div(y.mul(y).add(20.0));
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
//Valeur initiale(0, 0)La distance de.
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
Params params = new Params()
    .put("x", Nd4j.create(new double[] {init_pos[0]}))
    .put("y", Nd4j.create(new double[] {init_pos[1]}));
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));

Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (String key : optimizers.keySet()) {
    Optimizer optimizer = optimizers.get(key);
    params.put("x", Nd4j.create(new double[] {init_pos[0]}))
        .put("y", Nd4j.create(new double[] {init_pos[1]}));
    double min_distance = Double.MAX_VALUE;
    double last_distance = 0.0;
    double prevX = init_pos[0];
    double prevY = init_pos[1];
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        //Dessine le titre du graphique.
        image.text(key, -2, 7);
        //Tracez le premier point.
        image.plot(prevX, prevY);
        for (int i = 0; i < 30; ++i) {
            INDArray temp = df.apply(params.get("x"), params.get("y"));
            grads.put("x", temp.getColumn(0));
            grads.put("y", temp.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            last_distance = Math.hypot(x, y);
            if (last_distance < min_distance)
                min_distance = last_distance;
            //Tracez une ligne à partir du point précédent.
            image.line(prevX, prevY, x, y);
            //Tracez les valeurs.
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        //Assurez-vous qu'il est plus optimisé que la valeur initiale.
        assertTrue(last_distance < init_distance);
        assertTrue(min_distance < init_distance);
        //Exportez le graphique dans un fichier.
        image.writeTo(new File(outdir, key + ".png "));
    }
}

Le graphique résultant ressemble à ceci:

SGD Momentum AdaGrad Adam
SGD.png Momentum.png AdaGrad.png Adam.png

6.1.8 Comparaison des méthodes de mise à jour utilisant le jeu de données MNIST

// ch06/optimizer_compare_mnist.La version Java de py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int train_size = x_train.size(0);
int batch_size = 128;
int max_iterations = 2000;

// 1.Paramètres expérimentaux
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);

// 2.Début de la formation
for (int i = 0; i < max_iterations; ++i) {
    //Extraire les données de lot.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (String key : optimizers.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizers.get(key).update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3.Dessiner un graphique
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    double w = 1300;
    double h = 0.7;
    for (String key : train_loss.keySet()) {
        List<Double> loss = train_loss.get(key);
        graph.color(colors.get(key));
        graph.text(key, w, h);
        h += 0.05;
        graph.plot(0, loss.get(0));
        int step = 10;
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    graph.color(Color.BLACK);
    graph.text("côté=Nombre de répétitions(0,2000)Verticale=Valeur de la fonction de perte(0,1)", w, h);
    h += 0.05;
    graph.text("Comparaison de quatre méthodes de mise à jour pour les ensembles de données MNIST", w, h);
    if (!Constants.OptimizerImages.exists())
        Constants.OptimizerImages.mkdirs();
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png "));
}

Le graphique résultant ressemble à ceci: compare_mnist.png

Recommended Posts

Deep Learning Java from scratch 6.1 Mise à jour des paramètres
Deep Learning Java from scratch 6.4 Régularisation
Étudiez le Deep Learning à partir de zéro en Java.
Deep Learning Java à partir de zéro Chapitre 1 Introduction
Deep Learning Java à partir de zéro Chapitre 2 Perceptron
Deep Learning Java from scratch 6.3 Normalisation par lots
Deep Learning from scratch Java Chapter 4 Apprentissage des réseaux de neurones
[Deep Learning from scratch] dans Java 3. Réseau neuronal
Deep Learning Java à partir de zéro Chapitre 3 Réseau neuronal
Deep Learning Java from scratch 6.2 Valeur initiale du poids
[Apprentissage profond à partir de zéro] 2. Il n'existe pas de NumPy en Java.
La vie Java à partir de zéro
[Deep Learning from scratch] en Java 1. Pour le moment, différenciation et différenciation partielle
Apprendre Java (0)
Java scratch scratch
Pour l'apprentissage JAVA (2018-03-16-01)
Jour d'apprentissage Java 5
java learning day 2
java learning day 1
[Note] Créez un environnement Java à partir de zéro avec docker
Apprentissage rapide de Java "Introduction?" Partie 3 Parler de programmation
Changements de Java 8 à Java 11
Somme de Java_1 à 100
Java Learning 2 (Apprenez la méthode de calcul)
apprentissage java (expression conditionnelle)
Mémo d'apprentissage Java (méthode)
Évaluer la source Java à partir de Java
Apprendre Java (1) - Hello World
Interface d'historique d'apprentissage JAVA
Accédez à API.AI depuis Java
Mémo d'apprentissage Java (basique)
De Java à Ruby !!
Mémo d'apprentissage Java (interface)
Mémo d'apprentissage Java (héritage)
Enfant orienté objet!? J'ai essayé le Deep Learning avec Java (édition d'essai)
Voyons la Deep Java Library (DJL), une bibliothèque capable de gérer Deep Learning en Java, publiée par AWS.