Deep Learning Java from scratch 6.1 Parameter update

6.1.2 SGD

First, define the Optimizer interface.

public interface Optimizer {

    /** Updates params in place using the gradients in grads. */
    void update(Params params, Params grads);
}

The implementation of SGD is as follows.

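public class SGD implements Optimizer {

    /** Learning rate (learning coefficient). */
    final double lr;

    public SGD(double lr) {
        this.lr = lr;
    }

    public SGD() {
        // Assumed default: the constructor body is missing in the source;
        // 0.01 is the default of the book's Python SGD.
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        // p <- p - lr * g (subi updates p in place)
        params.update((p, g) -> p.subi(g.mul(lr)), grads);
    }
}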
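
To see the update rule in isolation, here is a minimal sketch that applies the same step to plain double arrays instead of Params and INDArray. The class SgdSketch and its step method are hypothetical names used only for illustration; they are not part of the code in this article.

```java
import java.util.Arrays;

/** Hypothetical illustration class; not part of the article's code. */
class SgdSketch {

    /** In-place SGD step: params[i] <- params[i] - lr * grads[i]. */
    static void step(double[] params, double[] grads, double lr) {
        for (int i = 0; i < params.length; i++) {
            params[i] -= lr * grads[i];
        }
    }

    public static void main(String[] args) {
        double[] p = {-7.0, 2.0};
        // Gradient of f(x, y) = x^2 / 20 + y^2 is (x / 10, 2y).
        double[] g = {p[0] / 10.0, 2.0 * p[1]};
        step(p, g, 0.5);
        System.out.println(Arrays.toString(p));
    }
}
```

Each coordinate moves against its own gradient component, scaled by the single learning rate lr. That is all SGD does, which is why it struggles when the gradient is much steeper in one direction than in another.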

6.1.4 Momentum

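public class Momentum implements Optimizer {

    final double lr, momentum;
    /** Velocity; created lazily with the same shape as the parameters. */
    Params v;

    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
        this.v = null;
    }

    public Momentum(double lr) {
        this(lr, 0.9);
    }

    public Momentum() {
        // Assumed default: the constructor body is missing in the source;
        // 0.01 is the default of the book's Python Momentum.
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        if (v == null)
            v = Params.zerosLike(params);
        // v <- momentum * v - lr * g (done in two in-place steps)
        v.update((v, g) -> v.muli(momentum), grads);
        v.update((v, g) -> v.subi(g.mul(lr)), grads);
        // p <- p + v
        params.update((p, v) -> p.addi(v), v);
    }
}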
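
The effect of the velocity term is easier to see on plain numbers. The following is a minimal sketch on double arrays; MomentumSketch is a hypothetical class written for this illustration, not the article's implementation.

```java
/** Hypothetical illustration class; not part of the article's code. */
class MomentumSketch {
    final double lr, momentum;
    double[] v; // velocity, allocated on the first step

    MomentumSketch(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
    }

    /** v <- momentum * v - lr * g, then p <- p + v (both in place). */
    void step(double[] params, double[] grads) {
        if (v == null) v = new double[params.length];
        for (int i = 0; i < params.length; i++) {
            v[i] = momentum * v[i] - lr * grads[i];
            params[i] += v[i];
        }
    }

    public static void main(String[] args) {
        MomentumSketch opt = new MomentumSketch(0.1, 0.9);
        double[] p = {0.0};
        double[] g = {1.0}; // constant gradient, to make the accumulation visible
        for (int i = 1; i <= 3; i++) {
            opt.step(p, g);
            System.out.println("step " + i + ": v=" + opt.v[0] + " p=" + p[0]);
        }
    }
}
```

With a constant gradient of 1, lr = 0.1 and momentum = 0.9, the step sizes grow from 0.1 to 0.19 to 0.271: the velocity keeps accumulating as long as the gradient points the same way, and it decays when the gradient changes sign.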

6.1.5 AdaGrad

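import org.nd4j.linalg.ops.transforms.Transforms;

public class AdaGrad implements Optimizer {

    final double lr;
    /** Running sum of squared gradients. */
    Params h;

    public AdaGrad(double lr) {
        this.lr = lr;
    }

    public AdaGrad() {
        // Assumed default: the constructor body is missing in the source;
        // 0.01 is the default of the book's Python AdaGrad.
        this(0.01);
    }

    @Override
    public void update(Params params, Params grads) {
        if (h == null)
            h = Params.zerosLike(params);
        // h <- h + g * g
        h.update((h, g) -> h.addi(g.mul(g)), grads);
        // p <- p - lr * g / (sqrt(h) + 1e-7); the 1e-7 avoids division by zero
        params.update((p, g, h) -> p.subi(g.mul(lr).div(Transforms.sqrt(h).add(1e-7))), grads, h);
    }
}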
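
As a small illustration of how the accumulated squared gradient shrinks the step, here is a minimal sketch on double arrays. AdaGradSketch is a hypothetical class introduced only for this example.

```java
/** Hypothetical illustration class; not part of the article's code. */
class AdaGradSketch {
    final double lr;
    double[] h; // running sum of squared gradients, allocated on the first step

    AdaGradSketch(double lr) {
        this.lr = lr;
    }

    /** h <- h + g * g, then p <- p - lr * g / (sqrt(h) + 1e-7). */
    void step(double[] params, double[] grads) {
        if (h == null) h = new double[params.length];
        for (int i = 0; i < params.length; i++) {
            h[i] += grads[i] * grads[i];
            params[i] -= lr * grads[i] / (Math.sqrt(h[i]) + 1e-7);
        }
    }

    public static void main(String[] args) {
        AdaGradSketch opt = new AdaGradSketch(1.0);
        double[] p = {0.0};
        double[] g = {2.0}; // constant gradient
        for (int i = 1; i <= 3; i++) {
            double before = p[0];
            opt.step(p, g);
            System.out.println("step " + i + ": moved by " + (before - p[0]));
        }
    }
}
```

With a constant gradient g, the n-th step has size lr * g / sqrt(n * g^2) = lr / sqrt(n), up to the small 1e-7 term. The step therefore decays as 1 / sqrt(n) and does not depend on the magnitude of g, which is the per-parameter scaling that AdaGrad provides.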

6.1.6 Adam

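import org.nd4j.linalg.ops.transforms.Transforms;

public class Adam implements Optimizer {

    final double lr, beta1, beta2;
    int iter;
    Params m, v;

    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.iter = 0;
    }

    public Adam(double lr) {
        this(lr, 0.9, 0.999);
    }

    public Adam() {
        // Assumed default: the constructor body is missing in the source;
        // 0.001 is the default of the book's Python Adam.
        this(0.001);
    }

    @Override
    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        }
        // Count the step first: with iter == 0 the correction below would be 0 / 0.
        ++iter;
        double lr_t = lr * Math.sqrt(1.0 - Math.pow(beta2, iter)) / (1.0 - Math.pow(beta1, iter));
        // m <- m + (1 - beta1) * (g - m)
        m.update((m, g) -> m.addi(g.sub(m).mul(1 - beta1)), grads);
        // v <- v + (1 - beta2) * (g * g - v)
        v.update((v, g) -> v.addi(g.mul(g).sub(v).mul(1 - beta2)), grads);
        // p <- p - lr_t * m / (sqrt(v) + 1e-7)
        params.update((p, m, v) -> p.subi(m.mul(lr_t).div(Transforms.sqrt(v).add(1e-7))), m, v);
    }
}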
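
The update m += (1 - beta1) * (g - m) is the same as the more familiar m = beta1 * m + (1 - beta1) * g, and likewise for v. The sketch below spells this out on double arrays; AdamSketch is a hypothetical class for illustration only. Note that the step counter has to be incremented before lr_t is computed, because with iter = 0 the bias correction evaluates to 0 / 0.

```java
/** Hypothetical illustration class; not part of the article's code. */
class AdamSketch {
    final double lr, beta1, beta2;
    int iter;      // number of steps taken so far
    double[] m, v; // first and second moment estimates

    AdamSketch(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
    }

    void step(double[] params, double[] grads) {
        if (m == null) {
            m = new double[params.length];
            v = new double[params.length];
        }
        iter++; // must happen before the bias-corrected learning rate is computed
        double lrT = lr * Math.sqrt(1.0 - Math.pow(beta2, iter))
                        / (1.0 - Math.pow(beta1, iter));
        for (int i = 0; i < params.length; i++) {
            // Equivalent to m = beta1 * m + (1 - beta1) * g.
            m[i] += (1 - beta1) * (grads[i] - m[i]);
            // Equivalent to v = beta2 * v + (1 - beta2) * g * g.
            v[i] += (1 - beta2) * (grads[i] * grads[i] - v[i]);
            params[i] -= lrT * m[i] / (Math.sqrt(v[i]) + 1e-7);
        }
    }

    public static void main(String[] args) {
        AdamSketch opt = new AdamSketch(0.001, 0.9, 0.999);
        double[] p = {0.0};
        opt.step(p, new double[] {5.0});
        System.out.println("after one step: " + p[0]);
    }
}
```

Thanks to the bias correction, the very first step has a magnitude of roughly lr whatever the size of the gradient: the factors (1 - beta1) and sqrt(1 - beta2) introduced by the zero-initialized m and v cancel against lr_t.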


6.1.7 Which update method should be used?

To draw the graphs, I created a simple class called GraphImage.

// Java version of ch06/optimizer_compare_naive.py.
// The graphs are drawn with GraphImage.
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) outdir.mkdirs();
// f(x, y) = x^2 / 20 + y^2 (only its gradient is needed below)
// BinaryOperator<INDArray> f = (x, y) ->
//     x.mul(x).div(20.0).add(y.mul(y));
// Gradient of f: (x / 10, 2y), returned as a single row [df/dx, df/dy].
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));

double[] init_pos = new double[] {-7.0, 2.0};
// Distance of the initial position from (0, 0).
double init_distance = Math.hypot(init_pos[0], init_pos[1]);
Params params = new Params()
    .put("x", Nd4j.create(new double[] {init_pos[0]}))
    .put("y", Nd4j.create(new double[] {init_pos[1]}));
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));

Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));

for (String key : optimizers.keySet()) {
    Optimizer optimizer = optimizers.get(key);
    // Reset the parameters to the initial position for each optimizer.
    params.put("x", Nd4j.create(new double[] {init_pos[0]}))
        .put("y", Nd4j.create(new double[] {init_pos[1]}));
    double min_distance = Double.MAX_VALUE;
    double last_distance = 0.0;
    double prevX = init_pos[0];
    double prevY = init_pos[1];
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        // Draw the title of the graph.
        image.text(key, -2, 7);
        // Plot the first point.
        image.plot(prevX, prevY);
        for (int i = 0; i < 30; ++i) {
            INDArray temp = df.apply(params.get("x"), params.get("y"));
            grads.put("x", temp.getColumn(0));
            grads.put("y", temp.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            last_distance = Math.hypot(x, y);
            if (last_distance < min_distance)
                min_distance = last_distance;
            // Draw a line from the previous point.
            image.line(prevX, prevY, x, y);
            // Plot the new position.
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        // Make sure the result is closer to the minimum (0, 0) than the initial position.
        assertTrue(last_distance < init_distance);
        assertTrue(min_distance < init_distance);
        // Write the graph to a file.
        image.writeTo(new File(outdir, key + ".png"));
    }
}

The resulting graphs look like this:

[Figures, one per optimizer: SGD.png, Momentum.png, AdaGrad.png, Adam.png]
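
The SGD part of this experiment can also be reproduced without ND4J or GraphImage. The following is a minimal sketch in plain Java, assuming the function is f(x, y) = x^2 / 20 + y^2 (the function whose gradient is (x / 10, 2y)); NaiveCompareSketch and its methods are hypothetical names for illustration.

```java
/** Hypothetical illustration class; not part of the article's code. */
class NaiveCompareSketch {

    /** Gradient of f(x, y) = x^2 / 20 + y^2. */
    static double[] df(double x, double y) {
        return new double[] {x / 10.0, 2.0 * y};
    }

    /** Runs plain SGD from (-7, 2) and returns the final position. */
    static double[] runSgd(double lr, int steps) {
        double x = -7.0, y = 2.0;
        for (int i = 0; i < steps; i++) {
            double[] g = df(x, y);
            x -= lr * g[0];
            y -= lr * g[1];
        }
        return new double[] {x, y};
    }

    public static void main(String[] args) {
        double[] p = runSgd(0.95, 30);
        System.out.println("initial distance: " + Math.hypot(-7.0, 2.0));
        System.out.println("final distance:   " + Math.hypot(p[0], p[1]));
    }
}
```

With lr = 0.95, each step multiplies x by 0.905 and y by -0.9. The y coordinate flips sign on every step while x creeps slowly toward 0, which is the zigzag path visible in the SGD graph.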

6.1.8 Comparison of update methods with MNIST dataset

// Java version of ch06/optimizer_compare_mnist.py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();

int train_size = x_train.size(0);
int batch_size = 128;
int max_iterations = 2000;

// 1. Experiment settings
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());

// One network and one loss history per optimizer.
Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);

// 2. Start of training
for (int i = 0; i < max_iterations; ++i) {
    // Extract a mini-batch.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (String key : optimizers.keySet()) {
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        optimizers.get(key).update(network.params, grads);
        double loss = network.loss(x_batch, t_batch);
        // Record the loss so that it can be plotted in step 3.
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}

// 3. Drawing a graph
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    // Position of the legend text in graph coordinates.
    double w = 1300;
    double h = 0.7;
    for (String key : train_loss.keySet()) {
        List<Double> loss = train_loss.get(key);
        graph.text(key, w, h);
        h += 0.05;
        graph.plot(0, loss.get(0));
        // Plot every 10th iteration.
        int step = 10;
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    graph.text("Horizontal = number of iterations (0, 2000), vertical = loss function value (0, 1)", w, h);
    h += 0.05;
    graph.text("Comparison of four update methods on the MNIST dataset", w, h);
    if (!Constants.OptimizerImages.exists())
        Constants.OptimizerImages.mkdirs();
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png"));
}

The resulting graph looks like this: compare_mnist.png
