Deep Learning Java von Grund auf 6.1 Parameteraktualisierung

Inhaltsverzeichnis

6.1.2 SGD

Definieren Sie zunächst die Schnittstelle Optimizer.

public interface Optimizer {

    void update(Params params, Params args);

}

Die Implementierung von SGD ist wie folgt.

public class SGD implements Optimizer {

    /** learning rate (Lernkoeffizient) */
    final double lr;

    public SGD(double lr) {
        this.lr = lr;
    }

    public SGD() {
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        params.update((p, g) -> p.subi(g.mul(lr)), grads);
    }
}

6.1.4 Momentum

public class Momentum implements Optimizer {

    final double lr, momentum;
    Params v;

    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
        this.v = null;
    }

    public Momentum(double lr) {
        this(lr, 0.9);
    }

    public Momentum() {
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        if (v == null)
            v = Params.zerosLike(params);
        v.update((v, g) -> v.muli(momentum), grads);
        v.update((v, g) -> v.subi(g.mul(lr)), grads);
        params.update((p, v) -> p.addi(v), v);
    }
}

6.1.5 AdaGrad

public class AdaGrad implements Optimizer {

    final double lr;
    Params h;

    public AdaGrad(double lr) {
        this.lr = lr;
    }

    public AdaGrad() {
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        if (h == null)
            h = Params.zerosLike(params);
        h.update((h, g) -> h.addi(g.mul(g)), grads);
        params.update((p, g, h) -> p.subi(g.mul(lr).div(Transforms.sqrt(h).add(1e-7))), grads, h);
    }
}

6.1.6 Adam

public class Adam implements Optimizer {

    final double lr, beta1, beta2;
    int iter;
    Params m, v;

    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.iter = 0;
    }

    public Adam(double lr) {
        this(lr, 0.9, 0.999);
    }

    public Adam() {
        this(0.001);
    }

    @Override
    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        }
        ++iter;
        double lr_t = lr * Math.sqrt(1.0 - Math.pow(beta2, iter)) / (1.0 - Math.pow(beta1, iter));
        m.update((m, g) -> m.addi(g.sub(m).mul(1 - beta1)), grads);
        v.update((v, g) -> v.addi(g.mul(g).sub(v).mul(1 - beta2)), grads);
        params.update((p, m, v) -> p.subi(m.mul(lr_t).div(Transforms.sqrt(v).add(1e-7))), m, v);
    }

}

6.1.7 Welche Aktualisierungsmethode sollte verwendet werden?

Erstellen Sie eine einfache Klasse GraphImage, um ein Diagramm zu erstellen tat.

// ch06/optimizer_compare_naive.Eine Java-Version von py.
//Erstellen Sie ein Diagramm mit GraphImage.
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) outdir.mkdirs();
// BinaryOperator<INDArray> f = (x, y) ->
// x.mul(x).div(y.mul(y).add(20.0));
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
//Ursprünglicher Wert(0, 0)Die Entfernung von.
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
Params params = new Params()
    .put("x", Nd4j.create(new double[] {init_pos[0]}))
    .put("y", Nd4j.create(new double[] {init_pos[1]}));
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));

Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (String key : optimizers.keySet()) {
    Optimizer optimizer = optimizers.get(key);
    params.put("x", Nd4j.create(new double[] {init_pos[0]}))
        .put("y", Nd4j.create(new double[] {init_pos[1]}));
    double min_distance = Double.MAX_VALUE;
    double last_distance = 0.0;
    double prevX = init_pos[0];
    double prevY = init_pos[1];
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        //Zeichnet den Titel des Diagramms.
        image.text(key, -2, 7);
        //Zeichnen Sie den ersten Punkt.
        image.plot(prevX, prevY);
        for (int i = 0; i < 30; ++i) {
            INDArray temp = df.apply(params.get("x"), params.get("y"));
            grads.put("x", temp.getColumn(0));
            grads.put("y", temp.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            last_distance = Math.hypot(x, y);
            if (last_distance < min_distance)
                min_distance = last_distance;
            //Zeichnen Sie eine Linie vom vorherigen Punkt.
            image.line(prevX, prevY, x, y);
            //Zeichnen Sie die Werte.
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        //Stellen Sie sicher, dass es optimierter als der Anfangswert ist.
        assertTrue(last_distance < init_distance);
        assertTrue(min_distance < init_distance);
        //Geben Sie das Diagramm in eine Datei aus.
        image.writeTo(new File(outdir, key + ".png "));
    }
}

Das resultierende Diagramm sieht folgendermaßen aus:

SGD Momentum AdaGrad Adam
SGD.png Momentum.png AdaGrad.png Adam.png

6.1.8 Vergleich der Aktualisierungsmethoden mit dem MNIST-Datensatz

// ch06/optimizer_compare_mnist.Java-Version von py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int train_size = x_train.size(0);
int batch_size = 128;
int max_iterations = 2000;

// 1.Experimentelle Einstellungen
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);

// 2.Beginn des Trainings
for (int i = 0; i < max_iterations; ++i) {
    //Chargendaten extrahieren.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (String key : optimizers.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizers.get(key).update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3.Zeichnen eines Diagramms
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    double w = 1300;
    double h = 0.7;
    for (String key : train_loss.keySet()) {
        List<Double> loss = train_loss.get(key);
        graph.color(colors.get(key));
        graph.text(key, w, h);
        h += 0.05;
        graph.plot(0, loss.get(0));
        int step = 10;
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    graph.color(Color.BLACK);
    graph.text("Seite=Anzahl der Wiederholungen(0,2000)Vertikal=Wert der Verlustfunktion(0,1)", w, h);
    h += 0.05;
    graph.text("Vergleich von vier Aktualisierungsmethoden für MNIST-Datensätze", w, h);
    if (!Constants.OptimizerImages.exists())
        Constants.OptimizerImages.mkdirs();
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png "));
}

Das resultierende Diagramm sieht folgendermaßen aus: compare_mnist.png

Recommended Posts

Deep Learning Java von Grund auf 6.1 Parameteraktualisierung
Deep Learning Java von Grund auf 6.4 Regularisierung
Lernen Sie Deep Learning von Grund auf in Java.
Deep Learning Java von Grund auf neu Kapitel 1 Einführung
Deep Learning Java von Grund auf neu Kapitel 2 Perceptron
Deep Learning Java von Grund auf 6.3 Batch-Normalisierung
Deep Learning von Grund auf neu Java Kapitel 4 Lernen neuronaler Netze
[Deep Learning von Grund auf neu] in Java 3. Neuronales Netzwerk
Deep Learning Java von Grund auf neu Kapitel 3 Neuronales Netzwerk
Deep Learning Java von Grund auf 6.2 Anfangswert des Gewichts
[Deep Learning von Grund auf neu] 2. In Java gibt es kein NumPy.
Java-Leben von vorne anfangen
[Deep Learning von Grund auf neu] in Java 1. Zur Zeit Differenzierung und teilweise Differenzierung
Java lernen (0)
Java Scratch Scratch
Für JAVA-Lernen (2018-03-16-01)
Java-Lerntag 5
Java-Lerntag 2
Java-Lerntag 1
[Hinweis] Erstellen Sie mit Docker eine Java-Umgebung von Grund auf neu
Schnell lernen Java "Einführung?" Teil 3 Von der Programmierung wegreden
Änderungen von Java 8 zu Java 11
Summe von Java_1 bis 100
Java Learning 2 (Lernen Sie die Berechnungsmethode)
Java-Lernen (bedingter Ausdruck)
Java-Lernnotiz (Methode)
Eval Java-Quelle von Java
Java lernen (1) -Hallo Welt
Greifen Sie über Java auf API.AI zu
Java-Lernnotiz (grundlegend)
Von Java zu Ruby !!
Java-Lernnotiz (Schnittstelle)
Java-Lernnotiz (Vererbung)
Objektorientiertes Kind !? Ich habe Deep Learning mit Java ausprobiert (Testversion)
Lassen Sie uns auf Deep Java Library (DJL) eingehen, eine von AWS veröffentlichte Bibliothek, die Deep Learning in Java verarbeiten kann.