L'interface BatchNormLayer, qui étend Layer afin de pouvoir transmettre train_flg, a été créée.
Layer.java
/**
 * Common contract for a network layer: one forward and one backward
 * transformation, each taking an {@code INDArray} and returning an {@code INDArray}.
 */
public interface Layer {
/**
 * Runs the forward pass on {@code x} and returns the layer output.
 */
INDArray forward(INDArray x);
/**
 * Runs the backward pass and returns its result.
 * NOTE(review): presumably {@code x} is the gradient coming from the next
 * layer -- confirm against the implementations.
 */
INDArray backward(INDArray x);
}
BatchNormLayer.java
public interface BatchNormLayer extends Layer {
public default INDArray forward(INDArray x) {
throw new IllegalAccessError();
}
INDArray forward(INDArray x, boolean train_flg);
}
Il reste ensuite à implémenter les classes Dropout et BatchNormalization, qui implémentent cette interface. On implémente en outre la classe MultiLayerNetExtend, puis on exécute le code de test suivant.
// Training images; assigned from MNISTImages.normalizedImages() by the test method.
INDArray x_train;
// Training labels; assigned from MNISTImages.oneHotLabels() by the test method.
INDArray t_train;
// Number of accuracy measurements after which __train stops.
int max_epochs = 20;
// Set to x_train.size(0) by the test method.
int train_size;
// Number of samples drawn from trainDataSet per training iteration.
int batch_size = 100;
// Passed to the SGD optimizer.
double learning_rate = 0.01;
// DataSet built from (x_train, t_train); mini-batches are sampled from it.
DataSet trainDataSet;
/**
 * Trains two networks side by side on the same mini-batches -- one built with
 * batch normalization enabled, one without -- and records the training
 * accuracy of each, once every {@code iter_per_epoch} iterations.
 *
 * @param weight_init_std value forwarded to the network constructor as the
 *                        weight initialisation setting
 * @return a two-element list: index 0 holds the accuracies of the network
 *         WITHOUT batch normalization, index 1 those of the network WITH it
 */
List<List<Double>> __train(String weight_init_std) {
    // The batch-norm network is constructed first and updated first,
    // matching the order of the networks list below.
    MultiLayerNetExtend bn_network = __newNetwork(weight_init_std, true);
    MultiLayerNetExtend network = __newNetwork(weight_init_std, false);
    List<MultiLayerNetExtend> networks = Arrays.asList(bn_network, network);
    Optimizer optimizer = new SGD(learning_rate);

    List<Double> plainAccuracies = new ArrayList<>();
    List<Double> batchNormAccuracies = new ArrayList<>();
    int iter_per_epoch = Math.max(train_size / batch_size, 1);
    int epoch_cnt = 0;

    for (int i = 0; i < 1000000000; ++i) {
        DataSet sample = trainDataSet.sample(batch_size);
        INDArray x_batch = sample.getFeatureMatrix();
        INDArray t_batch = sample.getLabels();
        for (MultiLayerNetExtend current : networks) {
            Params grads = current.gradient(x_batch, t_batch);
            optimizer.update(current.params, grads);
        }

        // Only measure accuracy on epoch boundaries.
        if (i % iter_per_epoch != 0) {
            continue;
        }
        double train_acc = network.accuracy(x_train, t_train);
        double bn_train_acc = bn_network.accuracy(x_train, t_train);
        plainAccuracies.add(train_acc);
        batchNormAccuracies.add(bn_train_acc);
        System.out.println("epoch:" + epoch_cnt + " | " + train_acc + " - " + bn_train_acc);
        epoch_cnt += 1;
        if (epoch_cnt >= max_epochs) {
            break;
        }
    }
    return Arrays.asList(plainAccuracies, batchNormAccuracies);
}

/** Builds the 784-input, five-hidden-layer (100 units each), 10-output network used by {@code __train}. */
private MultiLayerNetExtend __newNetwork(String weight_init_std, boolean useBatchNorm) {
    return new MultiLayerNetExtend(
            784, new int[] {100, 100, 100, 100, 100}, 10,
            /*activation=*/ "relu",
            /*weight_init_std=*/ weight_init_std,
            /*weight_decay_lambda=*/ 0,
            /*use_dropout=*/ false,
            /*dropout_ration=*/ 0.5,
            /*use_batchNorm=*/ useBatchNorm);
}
@Test
public void C6_3_2_Batch_Évaluation de la normalisation() throws IOException {
// ch06/batch_norm_test.La version Java de py.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
x_train = train.normalizedImages();
t_train = train.oneHotLabels();
trainDataSet = new DataSet(x_train, t_train);
train_size = x_train.size(0);
//Dessiner un graphique
File dir = Constants.WeightImages;
if (!dir.exists()) dir.mkdirs();
String[] names = {"BatchNormalization", "Normal"};
Color[] colors = {Color.BLUE, Color.RED};
INDArray weight_scale_list = Functions.logspace(0, -4, 16);
INDArray x = Functions.arrange(max_epochs);
for (int i = 0; i < weight_scale_list.length(); ++i) {
System.out.println( "============== " + (i+1) + "/16" + " ==============");
double w = weight_scale_list.getDouble(i);
List<List<Double>> acc_list = __train(String.valueOf(w));
GraphImage graph = new GraphImage(640, 480, -1, -0.1, 20, 1.0);
for (int j = 0; j < names.length; ++j) {
graph.color(colors[j]);
graph.textInt(names[j] + " : " + w, 20, 20 * j + 20);
graph.plot(0, acc_list.get(j).get(0));
for (int k = 1; k < acc_list.get(j).size(); ++k) {
graph.line(k - 1, acc_list.get(j).get(k - 1), k, acc_list.get(j).get(k));
graph.plot(k, acc_list.get(j).get(k));
}
}
File file = new File(dir, "BatchNormalization#" + w + ".png ");
graph.writeTo(file);
}
}
Le graphique résultant ressemble à ceci:
W=1.0
W=0.29286444187164307
W=0.00009999999747378752
Contrairement à ce que décrit le livre, plus W (l'écart type des poids initiaux) est petit, plus l'apprentissage progresse rapidement.