// Overfitting experiment: train a deep network on a deliberately small subset of
// MNIST, record train/test accuracy once per epoch, and plot both curves.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Keep only rows [0, 300) of the training data (small set, intended to provoke overfitting).
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
// The test set is used in full.
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));

// Weight decay setting: 0 means weight decay is not used.
double weight_decay_lambda = 0;
// Arguments are presumably (input size, hidden layer sizes, output size, ...) — confirm against MultiLayerNet.
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);

int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();
// Number of mini-batch iterations counted as one epoch (at least 1).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;

// Run mini-batch updates until max_epochs accuracy measurements have been recorded.
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    // Every iter_per_epoch iterations (starting with i == 0), measure accuracy on
    // the training subset and on the test set.
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// Draw the graph: train accuracy in blue, test accuracy in red, one point per epoch.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Erkennungsgenauigkeit beim Überlernen", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect each epoch's point to the previous one for both curves.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}

// Write the image into the output directory, creating the directory if needed.
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
// File name must not carry a trailing space.
graph.writeTo(new File(dir, "overfit.png"));
6.4.2 Weight decay
// Weight decay experiment: same small-data training run as the overfitting
// experiment, but with a non-zero weight decay coefficient. Train/test accuracy
// is recorded once per epoch and plotted.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Keep only rows [0, 300) of the training data.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
// The test set is used in full.
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();
assertEquals(300, x_train.size(0));

// Weight decay setting: use 0 instead to disable weight decay.
double weight_decay_lambda = 0.1;
// Arguments are presumably (input size, hidden layer sizes, output size, ...) — confirm against MultiLayerNet.
MultiLayerNet network = new MultiLayerNet(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu", /*weight_init_std*/"relu",
    /*weight_decay_lambda*/weight_decay_lambda);
Optimizer optimizer = new SGD(0.01);

int max_epochs = 201;
int train_size = x_train.size(0);
int batch_size = 100;
List<Double> train_acc_list = new ArrayList<>();
List<Double> test_acc_list = new ArrayList<>();
// Number of mini-batch iterations counted as one epoch (at least 1).
int iter_per_epoch = Math.max(train_size / batch_size, 1);
int epoch_cnt = 0;

// Run mini-batch updates until max_epochs accuracy measurements have been recorded.
for (int i = 0; epoch_cnt < max_epochs; ++i) {
    DataSet sample = dataset.sample(batch_size);
    INDArray x_batch = sample.getFeatureMatrix();
    INDArray t_batch = sample.getLabels();
    Params grads = network.gradient(x_batch, t_batch);
    optimizer.update(network.params, grads);

    // Every iter_per_epoch iterations (starting with i == 0), measure accuracy on
    // the training subset and on the test set.
    if (i % iter_per_epoch == 0) {
        double train_acc = network.accuracy(x_train, t_train);
        double test_acc = network.accuracy(x_test, t_test);
        train_acc_list.add(train_acc);
        test_acc_list.add(test_acc);
        System.out.println("epoch:" + epoch_cnt + ", train acc:" + train_acc + ", test acc:" + test_acc);
        ++epoch_cnt;
    }
}

// Draw the graph: train accuracy in blue, test accuracy in red, one point per epoch.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Erkennungsgenauigkeit beim Überlernen mithilfe des Gewichtsabfalls", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect each epoch's point to the previous one for both curves.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}

// Write the image into the output directory, creating the directory if needed.
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
// File name must not carry a trailing space.
graph.writeTo(new File(dir, "overfit_weight_decay.png"));
6.4.3 Dropout
// Dropout experiment: train a deep network with dropout on a small MNIST subset
// via Trainer, then plot the per-epoch train/test accuracy it collected.
MNISTImages train = new MNISTImages(Constants.TrainImages, Constants.TrainLabels);
// Reduce the training data to rows [0, 300) in order to reproduce overfitting.
INDArray x_train = train.normalizedImages().get(NDArrayIndex.interval(0, 300));
INDArray t_train = train.oneHotLabels().get(NDArrayIndex.interval(0, 300));
DataSet dataset = new DataSet(x_train, t_train);
// The test set is used in full.
MNISTImages test = new MNISTImages(Constants.TestImages, Constants.TestLabels);
INDArray x_test = test.normalizedImages();
INDArray t_test = test.oneHotLabels();

// Dropout on/off and dropout ratio settings.
boolean use_dropout = true; // set to false to run without dropout
double dropout_ratio = 0.2;

// Arguments are presumably (input size, hidden layer sizes, output size, ...) — confirm against MultiLayerNetExtend.
MultiLayerNetExtend network = new MultiLayerNetExtend(784, new int[] {100, 100, 100, 100, 100, 100}, 10,
    /*activation*/"relu",
    /*weight_init_std*/"relu",
    /*weight_decay_lambda*/0,
    /*use_dropout*/use_dropout, /*dropout_ratio*/dropout_ratio,
    /*use_batchnorm*/false);
Trainer trainer = new Trainer(network, x_train, t_train, x_test, t_test,
    /*epochs*/301,
    /*mini_batch_size*/100,
    /*optimizer*/() -> new SGD(0.01),
    /*evaluate_sample_num_per_epoch*/0,
    /*verbose*/true);
trainer.train();
// Accuracy histories are read from the trainer after training.
List<Double> train_acc_list = trainer.train_acc_list;
List<Double> test_acc_list = trainer.test_acc_list;

// Draw the graph: train accuracy in blue, test accuracy in red, one point per entry.
// NOTE(review): the x-range passed here ends at 200 while the trainer is configured
// for 301 epochs; entries past 200 may fall outside the drawn area — confirm
// against GraphImage and widen the range if the full curve is wanted.
GraphImage graph = new GraphImage(640, 480, -40, -0.1, 200, 1.0);
graph.color(Color.BLACK);
graph.textInt("Erkennungsgenauigkeit in Dropout", 10, 15);
graph.textInt("x=(" + graph.minX + "," + graph.maxX + ") y=(" + graph.minY + "," + graph.maxY + ")", 10, 30);
graph.color(Color.BLUE);
graph.textInt("train", 10, 45);
graph.plot(0, train_acc_list.get(0));
graph.color(Color.RED);
graph.textInt("test", 10, 60);
graph.plot(0, test_acc_list.get(0));
// Connect each point to the previous one for both curves.
for (int i = 1; i < train_acc_list.size(); ++i) {
    graph.color(Color.BLUE);
    graph.line(i - 1, train_acc_list.get(i - 1), i, train_acc_list.get(i));
    graph.plot(i, train_acc_list.get(i));
    graph.color(Color.RED);
    graph.line(i - 1, test_acc_list.get(i - 1), i, test_acc_list.get(i));
    graph.plot(i, test_acc_list.get(i));
}

// Write the image into the output directory, creating the directory if needed.
File dir = Constants.WeightImages;
if (!dir.exists()) {
    dir.mkdirs();
}
// File name must not carry a trailing space.
graph.writeTo(new File(dir, "dropout.png"));
Recommended Posts