[Java] Building an automatic classifier for chest vs. abdominal X-ray images: a deep learning quick start

17 minute read

Why I wrote this article

This is material from a study session we held two years ago; since it is already getting old, I am publishing it as a memorandum.

What I tried

Develop an application that can automatically distinguish between chest X-ray images and abdominal X-ray images.

Motivation

Being able to classify something as binary is important. No matter how complex a problem is, if you can organize it systematically, you can break it down into parts. If you keep subdividing the decomposed parts, you eventually arrive at a Yes-or-No choice, and by answering each piece in turn, repeated Yes/No decisions make even difficult problems solvable. Any theme would have been fine, but as a training subject for experiencing binary classification of medical images, this one was just right. It originates from this paper: Hello World Deep Learning in Medical Imaging

What I used

  • A laptop (any ordinary one)
  • Optional: NVIDIA-GPU (1050Ti this time)

Environment

  • Ubuntu 18.04 (since this is Java, the OS does not really matter)
  • maven + dl4j related
  • eclipse (2018)
  • JDK8 or higher
  • (People interested in this kind of theme have already done it in Python.)

data

It is published at the GitHub link below. https://github.com/paras42/Hello_World_Deep_Learning/tree/9921a12c905c00a88898121d5dc538e3b524e520 The image archive is “Open_I_abd_vs_CXRs.zip”. Abd presumably stands for abdomen and CXRs for chest X-rays. After downloading, unzip it before use. There are 75 images in all: 38 chest X-rays and 37 abdominal X-rays. The folder hierarchy is divided into TEST, TRAIN, and VAL (short for validation) folders, and the TRAIN and VAL folders each contain image folders for the chest and the abdomen. ![Screenshot from 2020-07-27 22-47-38.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/379410/e8ef6287-1661-5662-646d-3e0498d89edf.png) The TEST folder contains one chest and one abdominal image, and these are not sorted into folders.
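For reference, the layout is roughly the following (the per-class folder names are my recollection of the zip's contents, so treat them as assumptions):

Open_I_abd_vs_CXRs/
├── TRAIN/
│   ├── abdomen/ ... abdominal X-ray images
│   └── chest/   ... chest X-ray images
├── VAL/
│   ├── abdomen/
│   └── chest/
└── TEST/ ... one chest and one abdominal image, not sorted into folders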

Work

After decompressing the data, save it in an appropriate location. In my case, I placed it directly under the Maven project. ![Screenshot from 2020-07-31 10-48-59.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/379410/31e06aa0-6b33-7590-8bf6-5c2520a94847.png)

Code and commentary

The pom.xml is given at the end of this page, and the final code is summarized at the end of this section. Please note that the packages you see may differ depending on the version. (My impression is that they change frequently…)

setup

First, prepare the basic parameters and settings for learning.

During training, random numbers are used in various places, such as weight initialization and the automatic allocation of training data. This is very handy, but it is annoying if the results change on every run. Define a seed so that the same random sequence is used every time.

long seed = 42;
final Random RAND_NUM_GEN = new Random(seed);

Since we are targeting images this time, images have to be read in. If we accepted just any file, strange data that happens to be in the folder could get sucked in. To prevent this, restrict the file formats that may be read. The default used here accepts the common image formats.

final String[] ALLOWED_FORMATS = BaseImageLoader.ALLOWED_FORMATS;

In supervised machine learning, teacher labels are often created by hand, but with the following setting, folder names are automatically recognized as class names and labels are assigned automatically. (For example, a chest image 1.png becomes (0, 1): the positions follow the order in which the labels are arranged, and the matching class is set to 1.)

ParentPathLabelGenerator LABEL_GENERATOR_MAKER = new ParentPathLabelGenerator();
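As a hypothetical illustration (the folder names are assumptions matching the dataset layout above), ParentPathLabelGenerator derives each label from the file's parent folder:

// The immediate parent folder name becomes the class label:
//   ./Open_I_abd_vs_CXRs/TRAIN/abdomen/img_001.png -> label "abdomen"
//   ./Open_I_abd_vs_CXRs/TRAIN/chest/img_042.png   -> label "chest"
// Label names are sorted alphabetically, so "abdomen" gets index 0 and
// "chest" index 1 (matching the ABDOMEN/CHEST branch in the test code later);
// a chest image's one-hot teacher vector is then (0, 1).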

Next, set a filter so that when data is sent to training, it is selected at random while the number of examples per class stays balanced.

BalancedPathFilter PATH_FILTER = new BalancedPathFilter(RAND_NUM_GEN, ALLOWED_FORMATS, LABEL_GENERATOR_MAKER);

Make the basic settings required to train the model. As the comments say, numLabels is the number of labels; this time it is 2, since we classify chest vs. abdomen. height, width, and channels are the height and width of the images fed to the model (including the images you want to predict on) and the number of color channels. inputShape is an array combining these, and becomes the setting for the model's input layer. batchSize is the amount of data used in one iteration; the network weights are updated after this much data is processed. epochs is the number of passes over the training data.

int numLabels = 2;// chest or abd
int height = 64; // image size for train
int width = 64; // image size for train
int channels = 3;// image channels(in this case, image type is RGB, so 3 channels)
int[] inputShape = new int[] {channels, height, width};
int batchSize = 32;// number of samples per weight update (minibatch size)
int epochs = 50;

Image data input pipeline

Now that the basic settings are done, configure how the images are read. Build FileSplit and InputSplit objects by specifying the path to each data folder. Originally, these are used to automatically split images into training/validation/test sets, but since the split here is already done by folder, we do not divide the data in code; we simply build one input pipeline each for training, validation, and testing.

System.out.println("Preparing data....");
// Prepare train
File trainDir = new File("./Open_I_abd_vs_CXRs/TRAIN/");
FileSplit trainSplit = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
InputSplit train = trainSplit.sample(PATH_FILTER, 1.0)[0]; //Train everything
// Prepare val
File valDir = new File("./Open_I_abd_vs_CXRs/VAL/");
FileSplit valSplit = new FileSplit(valDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
InputSplit val = valSplit.sample(PATH_FILTER, 1.0)[0]; //Verify everything
// Prepare test
File testDir = new File("./Open_I_abd_vs_CXRs/TEST/");
FileSplit testSplit = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
InputSplit test = testSplit.sample(PATH_FILTER, 1.0)[0]; //All to test
System.out.println("train data total sample size "+ train.length());
System.out.println("validation total data sample size "+ val.length());
System.out.println("test data total sample size "+ test.length());

Augmentation (pseudo amplification of the data)

Since the amount of data in this dataset is very small (deep learning usually wants hundreds of samples per class), we amplify the data in a pseudo manner and then examine the model's accuracy. If this goes well, it suggests that a model on the right track could be developed if the amount of real data were increased later. Various methods can amplify images: flips, rotations, crops, positional shifts, affine transforms, and so on. The caveat is to avoid amplifying into impossible images; for example, ultrasound images contain posterior echoes, and there is a known failure case of creating pseudo images by rotating them 180°. I did not think that hard here, and simply used ImageTransform to set up random flips and a warp. Create several ImageTransforms, collect them in a List, and build them into a PipelineImageTransform to form a pipeline; the Double paired with each transform is the probability that it is applied. If PipelineImageTransform's shuffle is true, the pipeline order is chosen randomly; if false, the transforms are applied sequentially in list order.

System.out.println("Prepare augumentation....");ImageTransform flipTransform1 = new FlipImageTransform(new Random(seed));
ImageTransform flipTransform2 = new FlipImageTransform(new Random(seed));
ImageTransform warpTransform = new WarpImageTransform(new Random(seed), inputShape[1]/10);
boolean shuffle = false;
List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(new Pair<>(flipTransform1, 0.9),
new Pair<>(flipTransform2, 0.8), new Pair<>(warpTransform, 0.9));
ImageTransform transform = new PipelineImageTransform(pipeline, shuffle);

Linking image input and data amplification

Up to this point, the image-input part and the data-augmentation settings are complete; now we link them. Generally, you augment only the training data. Specify whether each input is augmented as in the code below; ImageRecordReader manages the image reading and the augmentation.

// data reader setup
ImageRecordReader recordReaderTrain = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
ImageRecordReader recordReaderVal = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
/*
* To match the folder structure of the distributed dataset,
* the test reader does not derive labels from the folder hierarchy.
* (If your own data uses the same folder hierarchy as the training data,
* you can label it the same way.)
*/
// ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels);
try {
// recordReaderTrain.initialize(train);// Train without transformations
recordReaderTrain.initialize(train,transform);// Train with transformations
recordReaderVal.initialize(val); //Do not augment the validation data
recordReaderTest.initialize(test);
} catch (IOException e) {
e.printStackTrace();
}

Model building

I wanted to keep things a little easy, so I borrowed a network called SimpleCNN from DL4J's model zoo. The example shown here is not the complete SimpleCNN: the final output layer has been tuned and added for this task. I did nothing difficult; I just copied the SimpleCNN.java code and gave it an output layer for multi-class classification (with 2 classes a binary output would also work, but I use softmax here). I omit the details, but the MultiLayerNetwork used here is DL4J's basic, simple CNN building block; the complicated, huge models that deep learning enthusiasts know well build on the same concepts (DL4J's ComputationGraph generalizes this MultiLayerNetwork).

System.out.println("Start construct SimpleCNN model...");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().trainingWorkspaceMode(WorkspaceMode.ENABLED)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED).seed(seed).activation(Activation.IDENTITY)
.weightInit(WeightInit.RELU).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new AdaDelta()).convolutionMode(ConvolutionMode.Same).list()
// block 1
.layer(0,
new ConvolutionLayer.Builder(new int[] {7,7}).name("image_array").nIn(inputShape[0]).nOut(16)
.build())
.layer(1, new BatchNormalization.Builder().build())
.layer(2, new ConvolutionLayer.Builder(new int[] {7, 7 }).nIn(16).nOut(16).build())
.layer(3, new BatchNormalization.Builder().build())
.layer(4, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(6, new DropoutLayer.Builder(0.5).build())

// block 2
.layer(7, new ConvolutionLayer.Builder(new int[] {5, 5 }).nOut(32).build())
.layer(8, new BatchNormalization.Builder().build())
.layer(9, new ConvolutionLayer.Builder(new int[] {5, 5 }).nOut(32).build())
.layer(10, new BatchNormalization.Builder().build())
.layer(11, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(12, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(13, new DropoutLayer.Builder(0.5).build())

// block 3
.layer(14, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(64).build())
.layer(15, new BatchNormalization.Builder().build())
.layer(16, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(64).build())
.layer(17, new BatchNormalization.Builder().build())
.layer(18, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(19, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(20, new DropoutLayer.Builder(0.5).build())

// block 4
.layer(21, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(128).build())
.layer(22, new BatchNormalization.Builder().build())
.layer(23, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(128).build())
.layer(24, new BatchNormalization.Builder().build())
.layer(25, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(26, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(27, new DropoutLayer.Builder(0.5).build())

// block 5
.layer(28, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(256).build())
.layer(29, new BatchNormalization.Builder().build())
.layer(30, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(256).build())
.layer(31, new GlobalPoolingLayer.Builder(PoolingType.AVG).build())

//output
.layer(32, new OutputLayer.Builder().nIn(256).nOut(2).lossFunction(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0]))
.backpropType(BackpropType.Standard)
.build();

MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
System.out.println(network.summary());

Visualize the learning process

To see how training progresses at each epoch, we use functionality built into DL4J. When you connect all the code in this article and run it, training proceeds; at that point, open a web browser, go to http://localhost:9000, and you can graphically check the training progress on your own PC.

// visualize train process
// URL:http://localhost:9000/train/overview
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
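One caveat: UIServer, StatsListener, and InMemoryStatsStorage live in DL4J's UI module. If your build cannot resolve them, you may need a dependency along these lines (the Scala-suffixed artifact name is my assumption for the beta4 era, so verify it against your version):

<dependency>
  <groupId>org.deeplearning4j</groupId>
  <artifactId>deeplearning4j-ui_2.11</artifactId>
  <version>${dl4j.version}</version>
</dependency>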

You can also configure how the training process is monitored. Commonly used are StatsListener, which collects general-purpose information about the model, and ScoreIterationListener, which reports the model's score (mainly the loss) at a specified iteration interval.

// set Stats Listener, to check confusion matrix for each epoch
network.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

From image input to model input

We are one breath away from training. We already have the image-input pipeline; now we add the piece that converts it for model input: the DataSetIterator. A DataSetIterator is responsible for preparing the data needed at each training iteration. This time we create three DataSetIterators: TRAIN, VAL (validation), and TEST. As you can see from the original data folders, the TEST images, unlike the others, are not assigned to per-class folders; they sit directly in the TEST folder. You could create folders and copy the files as with the other data, but this is a good opportunity to show how to read images that are not divided into folders.

DataSetIterator traindataIter = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, numLabels);
DataSetIterator valdataIter = new RecordReaderDataSetIterator(recordReaderVal, batchSize, 1, numLabels);
//In this example, the test folder does not have the same folder hierarchy as the others, so leave the label blank for the test data.
DataSetIterator testdataIter = new RecordReaderDataSetIterator(recordReaderTest, 1);//1 is the batch size

Normalization

The final step before training is to normalize the data fed to the model. Normalization is a method often used in statistics: a process that suppresses outliers, differences in maximum and minimum values between samples, and other things that could confuse the model during training. Here we configure a scaler that converts the image's pixel values (0-255) to numbers between 0 and 1. There are many kinds of scalers, but in general, a scaler fitted to the training data is also applied to the validation and test data. (The 0-1 range conversion used here is simple and needs no fitting, but doing so is good etiquette.)

// Normalization
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(traindataIter);
traindataIter.setPreProcessor(scaler);
valdataIter.setPreProcessor(scaler);
testdataIter.setPreProcessor(scaler);

Model training and validation

Repeat the model training epochs times. Just pass the prepared DataSetIterator to the model's fit() method; it automatically repeats data acquisition. At the end of each epoch we reset the iterators. (Depending on the model, the iterator may need to be cycled within one epoch, so check DL4J's examples and so on.)

We also validate at each epoch. The usual evaluation metrics and the confusion matrix are computed by network.evaluate(valdataIter);.

System.out.println("Start training model....");
int i = 0;
while (i < epochs) {
while (traindataIter.hasNext()) {
DataSet trained = traindataIter.next();
// System.out.println(trained.numExamples()); //same as batch size
network.fit(trained);
}
System.out.println("Evaluate model at iteration "+ i + "....");
Evaluation eval = network.evaluate(valdataIter);//use nd4j's Evaluation
System.out.println(eval.stats());
valdataIter.reset(); //return the iterator to the beginning
traindataIter.reset(); //return the iterator to the beginning
i++;
}
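As an aside, recent DL4J versions also provide a convenience overload that wraps this loop. A minimal sketch, assuming fit(DataSetIterator, int) exists in your version (it drops the per-epoch evaluation, so you only get one final report):

// Minimal alternative, assuming your DL4J version has fit(iterator, numEpochs);
// fit() handles the iterator resets internally, and we evaluate once at the end.
network.fit(traindataIter, epochs);
Evaluation finalEval = network.evaluate(valdataIter);
System.out.println(finalEval.stats());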

Model testing

Finally, let’s test with data that was used for neither training nor validation. Here I show how to feed images one at a time and check the results yourself, without using Evaluation.

/*
* If the folder hierarchy with the original image is the same as the training data folder hierarchy,
* Can be evaluated as above.
* Even if the folders are not organized
* You can rate each image as follows.
*/
System.out.println("Test model....");
while(testdataIter.hasNext()) {
DataSet testData = testdataIter.next();
System.out.println("testing... :"+testData.id());
INDArray input = testData.getFeatures();
INDArray pred = network.output(input);
System.out.println(pred);
int predLabel = Nd4j.argMax(pred).getInt(0); // predicted label index
if(predLabel == 0) {
System.out.println("ABDOMEN"+" with praba "+pred.getDouble(predLabel));
}else {
System.out.println("CHEST"+" with praba "+pred.getDouble(predLabel));
}
}
System.out.println("Finish....");

execute

The calculation process on the way can be visualized like this. ![Screenshot from 2020-07-31 12-53-11.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/379410/4937857b-efd2-281f-a851-dca950e13c10.png)

The evaluation along the way looks like the following; epoch 16 did quite well (parts omitted):

Evaluate model at iteration 15 ....
# of classes: 2
Accuracy: 0.9000
Precision: 0.9167
Recall: 0.9000
F1 Score: 0.8889

The output of the final test is:

Test model....
testing... : [[ 5.7758e-5, 0.9999]]
CHEST with proba 0.9999421834945679
testing... : [[ 0.5547, 0.4453]]
ABDOMEN with proba 0.5546808838844299
Finish....

It managed to judge the abdomen as well, though only barely. It is still a somewhat suspicious model.
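Incidentally, if you want to reuse a trained model later instead of retraining each time, DL4J's ModelSerializer can persist it. A minimal sketch (the file name is hypothetical; it needs an import of org.deeplearning4j.util.ModelSerializer):

try {
    // true = also save the updater state, in case you want to keep training later
    ModelSerializer.writeModel(network, new File("chest_or_abd_model.zip"), true);
    // ...and restore it in another run
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(new File("chest_or_abd_model.zip"));
    System.out.println(restored.summary());
} catch (IOException e) {
    e.printStackTrace();
}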

Complete code

It will be as follows.

ChestOrAbd.java



import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.WarpImageTransform;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ChestOrAbd {

 protected static final Logger log = LoggerFactory.getLogger(ChestOrAbd.class);

 public static void main(String[] args) {

  long seed = 42;
  final Random RAND_NUM_GEN = new Random(seed);
  final String[] ALLOWED_FORMATS = BaseImageLoader.ALLOWED_FORMATS;
  ParentPathLabelGenerator LABEL_GENERATOR_MAKER = new ParentPathLabelGenerator();
  BalancedPathFilter PATH_FILTER = new BalancedPathFilter(RAND_NUM_GEN, ALLOWED_FORMATS, LABEL_GENERATOR_MAKER);

  int numLabels = 2;// chest or abd
  int height = 64;// image size for train
  int width = 64;// image size for train
  int channels = 3;// image channels(in this case, image type is RGB, so 3 channels)
  int[] inputShape = new int[] {channels, height, width};
  int batchSize = 32;// number of samples per weight update (minibatch size)
  int epochs = 50;

  System.out.println("Preparing data....");
  // Prepare train
  File trainDir = new File("./Open_I_abd_vs_CXRs/TRAIN/");
  FileSplit trainSplit = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
  InputSplit train = trainSplit.sample(PATH_FILTER, 1.0)[0];
  // Prepare val
  File valDir = new File("./Open_I_abd_vs_CXRs/VAL/");
  FileSplit valSplit = new FileSplit(valDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
  InputSplit val = valSplit.sample(PATH_FILTER, 1.0)[0];
  // Prepare test
  File testDir = new File("./Open_I_abd_vs_CXRs/TEST/");
  FileSplit testSplit = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
  InputSplit test = testSplit.sample(PATH_FILTER, 1.0)[0];
  
  System.out.println("train data total sample size " + train.length());
  System.out.println("validation total data sample size " + val.length());
  System.out.println("test data total sample size " + test.length());

  System.out.println("Prepare augumentation....");
  ImageTransform flipTransform1 = new FlipImageTransform(new Random(seed));
  ImageTransform flipTransform2 = new FlipImageTransform(new Random(seed));
  ImageTransform warpTransform = new WarpImageTransform(new Random(seed), inputShape[1]/10);
  boolean shuffle = false;
  List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(new Pair<>(flipTransform1, 0.9),
    new Pair<>(flipTransform2, 0.8), new Pair<>(warpTransform, 0.9));
  ImageTransform transform = new PipelineImageTransform(pipeline, shuffle);

  // data reader setup
  ImageRecordReader recordReaderTrain = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
  ImageRecordReader recordReaderVal = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
  /*
   * To match the folder structure of the distributed dataset,
   * the test reader does not derive labels from the folder hierarchy.
   * (If your own data uses the same folder hierarchy as the training data,
   * you can label it the same way.)
   */
// ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels);
try {
// recordReaderTrain.initialize(train);// Train without transformations
recordReaderTrain.initialize(train,transform);// Train with transformations
recordReaderVal.initialize(val); //Do not augment the validation data
recordReaderTest.initialize(test);
} catch (IOException e) {
e.printStackTrace();
}

System.out.println("Start construct SimpleCNN model...");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().trainingWorkspaceMode(WorkspaceMode.ENABLED)
.inferenceWorkspaceMode(WorkspaceMode.ENABLED).seed(seed).activation(Activation.IDENTITY)
.weightInit(WeightInit.RELU).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new AdaDelta()).convolutionMode(ConvolutionMode.Same).list()
// block 1
.layer(0,
new ConvolutionLayer.Builder(new int[] {7,7}).name("image_array").nIn(inputShape[0]).nOut(16)
.build())
.layer(1, new BatchNormalization.Builder().build())
.layer(2, new ConvolutionLayer.Builder(new int[] {7, 7 }).nIn(16).nOut(16).build())
.layer(3, new BatchNormalization.Builder().build())
.layer(4, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(6, new DropoutLayer.Builder(0.5).build())

// block 2
.layer(7, new ConvolutionLayer.Builder(new int[] {5, 5 }).nOut(32).build())
.layer(8, new BatchNormalization.Builder().build())
.layer(9, new ConvolutionLayer.Builder(new int[] {5, 5 }).nOut(32).build())
.layer(10, new BatchNormalization.Builder().build())
.layer(11, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(12, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(13, new DropoutLayer.Builder(0.5).build())

// block 3
.layer(14, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(64).build())
.layer(15, new BatchNormalization.Builder().build())
.layer(16, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(64).build())
.layer(17, new BatchNormalization.Builder().build())
.layer(18, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(19, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(20, new DropoutLayer.Builder(0.5).build())

// block 4
.layer(21, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(128).build())
.layer(22, new BatchNormalization.Builder().build())
.layer(23, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(128).build())
.layer(24, new BatchNormalization.Builder().build())
.layer(25, new ActivationLayer.Builder().activation(Activation.RELU).build())
.layer(26, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] {2, 2 }).build())
.layer(27, new DropoutLayer.Builder(0.5).build())

// block 5
.layer(28, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(256).build())
.layer(29, new BatchNormalization.Builder().build())
.layer(30, new ConvolutionLayer.Builder(new int[] {3, 3 }).nOut(256).build())
.layer(31, new GlobalPoolingLayer.Builder(PoolingType.AVG).build())

//output
.layer(32, new OutputLayer.Builder().nIn(256).nOut(2)
.lossFunction(LossFunctions.LossFunction.MCXENT)
.weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0]))
.backpropType(BackpropType.Standard)
.build();

MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
System.out.println(network.summary());
// visualize train process
// URL:http://localhost:9000/train/overview
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
// set Stats Listener, to check confusion matrix for each epoch
network.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

/*
* This time there are only 2 classes,
* The teacher label is attached to the image according to the image type (for each folder).
* For example, image 1 (the answer is abdomen): (chest: 0, abdomen: 1).
* In this way, the corresponding person will have a "1".
* This number "1" is the label index.
* Four are set in the argument of DataSetIterator.
* recordReaderTrain, batchSize, 1, numLabels.
* Of these, 1 is the label index.
*/
DataSetIterator traindataIter = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, numLabels);
DataSetIterator valdataIter = new RecordReaderDataSetIterator(recordReaderVal, batchSize, 1, numLabels);
//In this example, the test folder does not have the same folder hierarchy as the others, so leave the label blank for the test data.
DataSetIterator testdataIter = new RecordReaderDataSetIterator(recordReaderTest, 1);//1 is the batch size
// Normalization
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(traindataIter);
traindataIter.setPreProcessor(scaler);
valdataIter.setPreProcessor(scaler);
testdataIter.setPreProcessor(scaler);
System.out.println("Start training model....");
int i = 0;
while (i < epochs) {
while (traindataIter.hasNext()) {
DataSet trained = traindataIter.next();
// System.out.println(trained.numExamples()); //same as batch size
network.fit(trained);
}
System.out.println("Evaluate model at iteration "+ i + "....");
Evaluation eval = network.evaluate(valdataIter);//use nd4j's Evaluation
System.out.println(eval.stats());
valdataIter.reset(); //return the iterator to the beginning
traindataIter.reset(); //return the iterator to the beginning
i++;
}

/*
* If the test data folder hierarchy is the same as the others,
* Can be evaluated as above.
* Even if the folders are not organized
* You can rate each image as follows.
*/
System.out.println("Test model....");
while(testdataIter.hasNext()) {
DataSet testData = testdataIter.next();
System.out.println("testing... :"+testData.id());
INDArray input = testData.getFeatures();
INDArray pred = network.output(input);
System.out.println(pred);
int predLabel = Nd4j.argMax(pred).getInt(0); // predicted label index
if(predLabel == 0) {
System.out.println("ABDOMEN"+" with praba "+pred.getDouble(predLabel));
}else {
System.out.println("CHEST"+" with praba "+pred.getDouble(predLabel));
}
}
System.out.println("Finish....");
}
}

Impression

In my case, once I could do this much, I began daydreaming about all kinds of next steps: transfer learning; incorporating layers I could not use well this time; leveling up to more complicated models (ComputationGraph), or conversely, trial and error toward simplification; using RNNs and LSTMs; and challenges beyond classification problems. I would like to keep at it so I can keep up with the topics of the times.

Reference

Reference POM

pom.xml



<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.vis</groupId>
  <artifactId>ChestOrAbd</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  
  <properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
<nd4j.version>1.0.0-beta4</nd4j.version>
<dl4j.version>1.0.0-beta4</dl4j.version>
<datavec.version>1.0.0-beta4</datavec.version>
<arbiter.version>1.0.0-beta4</arbiter.version>
<logback.version>1.2.3</logback.version>
<dl4j.spark.version>1.0.0-beta4_spark_2</dl4j.spark.version>
</properties>

<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.0-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
<version>${dl4j.spark.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId>
<version>${arbiter.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-ui_2.11</artifactId>
<version>${arbiter.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-codec</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.3.5</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>2.11.0</version>
</dependency>
</dependencies>
</project>