6.1.2 SGD
Tout d'abord, définissez l'interface Optimizer.
/**
 * Strategy that applies one training update step to a set of parameters.
 */
public interface Optimizer {
/**
 * Performs a single update step on {@code params}.
 *
 * @param params the parameters to update
 * @param args   the values driving the update; every implementation shown
 *               alongside this interface names this argument {@code grads}
 *               and uses it as the gradients of {@code params}
 */
void update(Params params, Params args);
}
La mise en œuvre de SGD est la suivante.
/**
 * Plain stochastic gradient descent: every parameter is moved against its
 * gradient by a fixed step size.
 */
public class SGD implements Optimizer {

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LR = 0.01;

    /** Learning rate (step size applied to every gradient). */
    final double lr;

    /** Creates an optimizer with the default learning rate of 0.01. */
    public SGD() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an optimizer with the given learning rate.
     *
     * @param lr the learning rate
     */
    public SGD(double lr) {
        this.lr = lr;
    }

    /**
     * Applies {@code p -= lr * g} to each parameter, in place.
     *
     * @param params the parameters to update
     * @param grads  the gradients paired with {@code params}
     */
    @Override
    public void update(Params params, Params grads) {
        params.update((param, grad) -> param.subi(grad.mul(lr)), grads);
    }
}
6.1.4 Momentum
/**
 * SGD with momentum: a velocity term accumulates past gradients and the
 * parameters are moved by that velocity.
 */
public class Momentum implements Optimizer {

    /** Learning rate used when none is given. */
    private static final double DEFAULT_LR = 0.01;

    /** Momentum coefficient used when none is given. */
    private static final double DEFAULT_MOMENTUM = 0.9;

    /** Learning rate. */
    final double lr;

    /** Decay factor applied to the previous velocity on every step. */
    final double momentum;

    /** Velocity per parameter; created lazily on the first update. */
    Params v;

    /** Creates an optimizer with learning rate 0.01 and momentum 0.9. */
    public Momentum() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an optimizer with the given learning rate and momentum 0.9.
     *
     * @param lr the learning rate
     */
    public Momentum(double lr) {
        this(lr, DEFAULT_MOMENTUM);
    }

    /**
     * Creates an optimizer with explicit hyper-parameters.
     *
     * @param lr       the learning rate
     * @param momentum the velocity decay factor
     */
    public Momentum(double lr, double momentum) {
        this.lr = lr;
        this.momentum = momentum;
    }

    /**
     * Applies {@code v = momentum * v - lr * g} followed by {@code p += v},
     * both in place.
     *
     * @param params the parameters to update
     * @param grads  the gradients paired with {@code params}
     */
    @Override
    public void update(Params params, Params grads) {
        if (v == null) {
            // First call: start from a zero velocity shaped like the parameters.
            v = Params.zerosLike(params);
        }
        v.update((velocity, grad) -> velocity.muli(momentum).subi(grad.mul(lr)), grads);
        params.update((param, velocity) -> param.addi(velocity), v);
    }
}
6.1.5 AdaGrad
/**
 * AdaGrad: scales each parameter's step by the inverse square root of the
 * accumulated squared gradients.
 */
public class AdaGrad implements Optimizer {

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LR = 0.01;

    /** Small constant added to the denominator to avoid division by zero. */
    private static final double EPSILON = 1e-7;

    /** Learning rate. */
    final double lr;

    /** Running sum of squared gradients; created lazily on the first update. */
    Params h;

    /** Creates an optimizer with the default learning rate of 0.01. */
    public AdaGrad() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an optimizer with the given learning rate.
     *
     * @param lr the learning rate
     */
    public AdaGrad(double lr) {
        this.lr = lr;
    }

    /**
     * Applies {@code h += g * g} and then
     * {@code p -= lr * g / (sqrt(h) + 1e-7)}.
     *
     * @param params the parameters to update
     * @param grads  the gradients paired with {@code params}
     */
    @Override
    public void update(Params params, Params grads) {
        if (h == null) {
            h = Params.zerosLike(params);
        }
        h.update((sumSquares, grad) -> sumSquares.addi(grad.mul(grad)), grads);
        params.update(
            (param, grad, sumSquares) ->
                param.subi(grad.mul(lr).div(Transforms.sqrt(sumSquares).add(EPSILON))),
            grads, h);
    }
}
6.1.6 Adam
/**
 * Adam: keeps exponential moving averages of the gradient and of the squared
 * gradient, and uses a bias-corrected step size.
 */
public class Adam implements Optimizer {

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LR = 0.001;

    /** Default decay rate of the first-moment average. */
    private static final double DEFAULT_BETA1 = 0.9;

    /** Default decay rate of the second-moment average. */
    private static final double DEFAULT_BETA2 = 0.999;

    /** Small constant added to the denominator to avoid division by zero. */
    private static final double EPSILON = 1e-7;

    /** Base learning rate. */
    final double lr;

    /** Decay rate of the first-moment (gradient) average. */
    final double beta1;

    /** Decay rate of the second-moment (squared gradient) average. */
    final double beta2;

    /** Number of update steps performed so far. */
    int iter;

    /** First-moment average; created lazily on the first update. */
    Params m;

    /** Second-moment average; created lazily on the first update. */
    Params v;

    /** Creates an optimizer with lr = 0.001, beta1 = 0.9, beta2 = 0.999. */
    public Adam() {
        this(DEFAULT_LR);
    }

    /**
     * Creates an optimizer with the given learning rate and default betas.
     *
     * @param lr the learning rate
     */
    public Adam(double lr) {
        this(lr, DEFAULT_BETA1, DEFAULT_BETA2);
    }

    /**
     * Creates an optimizer with explicit hyper-parameters.
     *
     * @param lr    the learning rate
     * @param beta1 decay rate of the first-moment average
     * @param beta2 decay rate of the second-moment average
     */
    public Adam(double lr, double beta1, double beta2) {
        this.lr = lr;
        this.beta1 = beta1;
        this.beta2 = beta2;
    }

    /**
     * Performs one Adam step:
     * {@code m += (1 - beta1) * (g - m)},
     * {@code v += (1 - beta2) * (g * g - v)},
     * {@code p -= lr_t * m / (sqrt(v) + 1e-7)}.
     *
     * @param params the parameters to update
     * @param grads  the gradients paired with {@code params}
     */
    @Override
    public void update(Params params, Params grads) {
        if (m == null) {
            m = Params.zerosLike(params);
            v = Params.zerosLike(params);
        }
        iter++;
        // Bias-corrected step size for the current iteration.
        final double secondMomentCorrection = 1.0 - Math.pow(beta2, iter);
        final double firstMomentCorrection = 1.0 - Math.pow(beta1, iter);
        final double stepSize = lr * Math.sqrt(secondMomentCorrection) / firstMomentCorrection;
        m.update((mean, grad) -> mean.addi(grad.sub(mean).mul(1 - beta1)), grads);
        v.update((var, grad) -> var.addi(grad.mul(grad).sub(var).mul(1 - beta2)), grads);
        params.update(
            (param, mean, var) ->
                param.subi(mean.mul(stepSize).div(Transforms.sqrt(var).add(EPSILON))),
            m, v);
    }
}
Pour dessiner les graphiques, on utilise une classe simple, GraphImage, créée à cet effet.
// Java version of ch06/optimizer_compare_naive.py.
// Runs each optimizer for 30 steps from the same start point and draws the
// path it follows with GraphImage, one PNG per optimizer.
File outdir = Constants.OptimizerImages;
if (!outdir.exists()) {
    outdir.mkdirs();
}
// df is the gradient of f(x, y) = x^2 / 20 + y^2, i.e. (x / 10, 2 * y).
// Only the gradient is evaluated; f itself is never computed here.
BinaryOperator<INDArray> df = (x, y) -> Nd4j.concat(1, x.div(10.0), y.mul(2.0));
double[] initPos = new double[] {-7.0, 2.0};
// Distance of the start point from (0, 0).
double initDistance = Math.hypot(initPos[0], initPos[1]);
Params params = new Params()
    .put("x", Nd4j.create(new double[] {initPos[0]}))
    .put("y", Nd4j.create(new double[] {initPos[1]}));
Params grads = new Params()
    .put("x", Nd4j.create(new double[] {0}))
    .put("y", Nd4j.create(new double[] {0}));
Map<String, Optimizer> optimizers = new LinkedHashMap<>();
optimizers.put("SGD", new SGD(0.95));
optimizers.put("Momentum", new Momentum(0.1));
optimizers.put("AdaGrad", new AdaGrad(1.5));
optimizers.put("Adam", new Adam(0.3));
for (Map.Entry<String, Optimizer> entry : optimizers.entrySet()) {
    String key = entry.getKey();
    Optimizer optimizer = entry.getValue();
    // Reset the parameters to the start point for every optimizer.
    params.put("x", Nd4j.create(new double[] {initPos[0]}))
        .put("y", Nd4j.create(new double[] {initPos[1]}));
    double minDistance = Double.MAX_VALUE;
    double lastDistance = 0.0;
    double prevX = initPos[0];
    double prevY = initPos[1];
    try (GraphImage image = new GraphImage(700, 700, -10, -10, 10, 10)) {
        // Draw the graph title.
        image.text(key, -2, 7);
        // Plot the start point.
        image.plot(prevX, prevY);
        for (int i = 0; i < 30; ++i) {
            INDArray gradient = df.apply(params.get("x"), params.get("y"));
            grads.put("x", gradient.getColumn(0));
            grads.put("y", gradient.getColumn(1));
            optimizer.update(params, grads);
            double x = params.get("x").getDouble(0);
            double y = params.get("y").getDouble(0);
            lastDistance = Math.hypot(x, y);
            if (lastDistance < minDistance) {
                minDistance = lastDistance;
            }
            // Connect the previous point to the new one, then mark the new one.
            image.line(prevX, prevY, x, y);
            image.plot(x, y);
            prevX = x;
            prevY = y;
        }
        // The optimizer must end up closer to the origin than it started.
        assertTrue(lastDistance < initDistance);
        assertTrue(minDistance < initDistance);
        // Write the graph to a file. The extension must not carry a trailing
        // space, otherwise the file is created with a malformed name.
        image.writeTo(new File(outdir, key + ".png"));
    }
}
Le graphique résultant ressemble à ceci:
SGD | Momentum | AdaGrad | Adam |
---|---|---|---|
// Java version of ch06/optimizer_compare_mnist.py.
// Trains one identical network per optimizer on the same mini-batches and
// plots the recorded training loss of each.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
INDArray x_train = train.normalizedImages();
INDArray t_train = train.oneHotLabels();
int batch_size = 128;
int max_iterations = 2000;
// 1. Experiment settings
Map<String, Optimizer> optimizers = new HashMap<>();
optimizers.put("SGD", new SGD());
optimizers.put("Momentum", new Momentum());
optimizers.put("AdaGrad", new AdaGrad());
optimizers.put("Adam", new Adam());
// optimizers.put("RMSprop", new RMSprop());
Map<String, MultiLayerNet> networks = new HashMap<>();
Map<String, List<Double>> train_loss = new HashMap<>();
for (String key : optimizers.keySet()) {
    networks.put(key, new MultiLayerNet(
        784, new int[] {100, 100, 100, 100}, 10));
    train_loss.put(key, new ArrayList<>());
}
DataSet dataset = new DataSet(x_train, t_train);
// 2. Training
for (int i = 0; i < max_iterations; ++i) {
    // Draw one mini-batch shared by all optimizers for this iteration.
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    for (Map.Entry<String, Optimizer> entry : optimizers.entrySet()) {
        String key = entry.getKey();
        MultiLayerNet network = networks.get(key);
        Params grads = network.gradicent(x_batch, t_batch);
        entry.getValue().update(network.params, grads);
        // Record the loss on the same batch after the update.
        double loss = network.loss(x_batch, t_batch);
        train_loss.get(key).add(loss);
    }
    if (i % 100 == 0) {
        System.out.println("===========" + "iteration:" + i + "===========");
        for (String key : optimizers.keySet()) {
            double loss = networks.get(key).loss(x_batch, t_batch);
            System.out.println(key + ":" + loss);
        }
    }
}
// 3. Draw the graph
try (GraphImage graph = new GraphImage(1000, 800, -100, -0.1, 2000, 1.0)) {
    Map<String, Color> colors = new HashMap<>();
    colors.put("SGD", Color.GREEN);
    colors.put("Momentum", Color.BLUE);
    colors.put("AdaGrad", Color.RED);
    colors.put("Adam", Color.ORANGE);
    // Position of the legend; each entry is placed 0.05 above the previous one.
    double legendX = 1300;
    double legendY = 0.7;
    for (Map.Entry<String, List<Double>> entry : train_loss.entrySet()) {
        String key = entry.getKey();
        List<Double> loss = entry.getValue();
        graph.color(colors.get(key));
        graph.text(key, legendX, legendY);
        legendY += 0.05;
        graph.plot(0, loss.get(0));
        // Only every 10th recorded loss is drawn.
        int step = 10;
        for (int i = step, size = loss.size(); i < size; i += step) {
            graph.line(i - step, loss.get(i - step), i, loss.get(i));
            graph.plot(i, loss.get(i));
        }
    }
    graph.color(Color.BLACK);
    graph.text("côté=Nombre de répétitions(0,2000)Verticale=Valeur de la fonction de perte(0,1)", legendX, legendY);
    legendY += 0.05;
    graph.text("Comparaison de quatre méthodes de mise à jour pour les ensembles de données MNIST", legendX, legendY);
    if (!Constants.OptimizerImages.exists()) {
        Constants.OptimizerImages.mkdirs();
    }
    // The extension must not carry a trailing space, otherwise the file is
    // created with a malformed name.
    graph.writeTo(new File(Constants.OptimizerImages, "compare_mnist.png"));
}
Le graphique résultant ressemble à ceci: