Deep learning quick start: creating an automatic classifier for chest and abdominal X-ray images

Why I wrote this article

This comes from a study session I held two years ago. It is a bit dated now, so I am publishing it as a memorandum.

What I tried

Develop an application that can automatically distinguish chest X-ray images from abdominal X-ray images.

Motivation

Binary classification matters. No matter how complicated a task is, if it can be organized systematically it can be decomposed into parts, and if you keep breaking those parts into smaller pieces you eventually end up with a series of Yes or No choices. Subdivide the problem and answer each piece, and you can (maybe) solve a difficult problem just by repeatedly choosing Yes or No. Any theme would have done, but as training material for experiencing binary classification of medical images, this one was just right. It is originally based on this paper: Hello World Deep Learning in Medical Imaging.

What I used

- Laptop (nothing special)
- Optional: NVIDIA GPU (a 1050 Ti in my case)

Environment

- Ubuntu 18.04 (since this is Java, the OS hardly matters)
- Maven + the DL4J family of libraries

Data

The data is published at the following GitHub link: https://github.com/paras42/Hello_World_Deep_Learning/tree/9921a12c905c00a88898121d5dc538e3b524e520. The image archive is "Open_I_abd_vs_CXRs.zip"; abd presumably stands for abdomen and CXRs for chest X-rays. Download and unzip it before use. There are 75 images in total: 38 chest X-rays and 37 abdominal X-rays. The archive is divided into TRAIN, VAL (short for validation), and TEST folders, and the TRAIN and VAL folders each contain a chest image folder and an abdomen image folder. The TEST folder contains one chest image and one abdominal image, which are not sorted into subfolders.
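As a sketch, the hierarchy looks roughly like this (the class folder names are illustrative; check the unzipped archive for the exact names):

	Open_I_abd_vs_CXRs/
	├── TRAIN/
	│   ├── abdomen/   <- abdominal X-ray images
	│   └── chest/     <- chest X-ray images
	├── VAL/
	│   ├── abdomen/
	│   └── chest/
	└── TEST/
	    ├── (one abdominal image, directly in the folder)
	    └── (one chest image, directly in the folder)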

Work

After unzipping the data, save it in a suitable location. In my case, I put it directly under the Maven project.

Code and commentary

The pom.xml is listed at the end of this page, and the complete code is collected at the end of this section. Note that the packages you need to import may differ depending on the DL4J version (my impression is that they change frequently...).

Setup

First, prepare the basic parameters and settings for learning.

Random numbers are used in many places during training, such as weight initialization and the automatic shuffling of training data. This is very convenient, but it is annoying if the results change on every run. Define a seed so that the same random sequence is produced every time.

		long seed = 42;
		final Random RAND_NUM_GEN = new Random(seed);
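As a trivial aside (not part of the training code): two generators created from the same seed produce exactly the same sequence, which is what makes runs reproducible.

		Random a = new Random(seed);
		Random b = new Random(seed);
		System.out.println(a.nextInt(100) == b.nextInt(100));// true: same seed, same sequence
		System.out.println(a.nextInt(100) == b.nextInt(100));// true again on the next draw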

Since we are targeting images, the images have to be read in. If you read in just anything, stray files that happen to be in the folder may get sucked in as well. To prevent this, restrict the input to recognized image formats. The default used here accepts the common general-purpose image formats.

		final String[] ALLOWED_FORMATS = BaseImageLoader.ALLOWED_FORMATS;

In supervised machine learning, teacher labels often have to be created by hand, but with the following setting the folder names are automatically recognized as class names and the labels are assigned automatically. (In this case, a chest image such as 1.png gets the one-hot label [0,1], where the left position is abdomen and the right is chest, following the alphabetical order of the labels. The value marking the correct class could in principle be anything, but "1" is the convention.)

		ParentPathLabelGenerator LABEL_GENERATOR_MAKER = new ParentPathLabelGenerator();
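To see what this does, here is a small sketch (the file path is hypothetical): ParentPathLabelGenerator simply takes the name of a file's parent folder as its label.

		// Hypothetical path, for illustration only
		org.datavec.api.writable.Writable label =
				LABEL_GENERATOR_MAKER.getLabelForPath("./Open_I_abd_vs_CXRs/TRAIN/chest/img001.png");
		System.out.println(label);// prints "chest"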

Next, set up a filter that randomly samples the data while keeping the number of examples per class balanced when feeding them to training.

		BalancedPathFilter PATH_FILTER = new BalancedPathFilter(RAND_NUM_GEN, ALLOWED_FORMATS, LABEL_GENERATOR_MAKER);

Make the basic settings required to train the model. As the comments indicate, numLabels is the number of labels; since we classify chest versus abdomen, there are two. height, width, and channels define the pixel dimensions and the color channels of the images fed to the model (and of the images you later want to predict), and inputShape combines them into an array that configures the model's input layer. batchSize is the number of samples processed in a single training step; the network weights are updated after each batch. epochs is the number of passes over the training data.

		int numLabels = 2;// chest or abd
		int height = 64;// image size for train
		int width = 64;// image size for train
		int channels = 3;// image channels(in this case, image type is RGB, so 3 channels)
		int[] inputShape = new int[] {channels, height, width};
		int batchSize = 32;// number of samples per mini-batch (weights update after each batch)
		int epochs = 50;

Image data input pipeline

Now that the basic settings are complete, configure how the images are read. Specify the path to each data folder you want to read and build FileSplit and InputSplit objects. These are normally used to split images into training/validation/test sets automatically, but here the split is already fixed by the folder structure, so no splitting is done in code; we simply build one input pipeline each for training, validation, and testing.

		System.out.println("Preparing data....");
		// Prepare train
		File trainDir = new File("./Open_I_abd_vs_CXRs/TRAIN/");
		FileSplit trainSplit = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
		InputSplit train = trainSplit.sample(PATH_FILTER, 1.0)[0];//To train everything
		// Prepare val
		File valDir = new File("./Open_I_abd_vs_CXRs/VAL/");
		FileSplit valSplit = new FileSplit(valDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
		InputSplit val = valSplit.sample(PATH_FILTER, 1.0)[0];//To verify everything
		// Prepare test
		File testDir = new File("./Open_I_abd_vs_CXRs/TEST/");
		FileSplit testSplit = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
		InputSplit test = testSplit.sample(PATH_FILTER, 1.0)[0];//Test everything
		
		System.out.println("train data total sample size " + train.length());
		System.out.println("validation total data sample size " + val.length());
		System.out.println("test data total sample size " + test.length());

Augmentation (generating pseudo-data)

Since this dataset is very small (deep learning typically wants hundreds of images per class), we generate additional data in a pseudo manner and then examine the model's accuracy. If this goes well, it suggests we could also develop a model that works well with more data. There are many ways to augment an image: flipping, rotation, cropping, positional shifts, affine transforms, and so on. The caveat is not to generate images that could never occur; for example, rotating an ultrasound image by 180° produces an impossible image, because the posterior echo would end up on the wrong side. Here I was not that strict, and used ImageTransform to apply random flips and warps. Create several ImageTransforms, collect them in a List together with the probability that each one is applied, and build them into a PipelineImageTransform to complete the pipeline. If shuffle is true, the pipeline order is chosen randomly; if false, the transforms are applied sequentially in list order.

		System.out.println("Prepare augumentation....");
		ImageTransform flipTransform1 = new FlipImageTransform(new Random(seed));
		ImageTransform flipTransform2 = new FlipImageTransform(new Random(seed));
		ImageTransform warpTransform = new WarpImageTransform(new Random(seed), inputShape[1]/10);
		boolean shuffle = false;
		List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(new Pair<>(flipTransform1, 0.9),
				new Pair<>(flipTransform2, 0.8), new Pair<>(warpTransform, 0.9));
		ImageTransform transform = new PipelineImageTransform(pipeline, shuffle);

Linking image input and data augmentation

Up to this point we have configured the image-input side and the augmentation side; what remains is to link them. In general, only the training data is augmented. Specify whether each input is augmented as shown in the code below; ImageRecordReader is in charge of managing both the image input and the augmentation.

		// data reader setup
		ImageRecordReader recordReaderTrain = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
		ImageRecordReader recordReaderVal = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
		/*
		 * The TEST folder keeps the structure of the original distribution,
		 * so labels cannot be derived from the folder hierarchy for the test data.
		 * (If you want labeled test data, arrange its folders like the others.)
		 */
//		ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
		ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels);
		try {
//			recordReaderTrain.initialize(train);// Train without transformations
			recordReaderTrain.initialize(train,transform);// Train with transformations
			recordReaderVal.initialize(val);//No augmentation on validation data
			recordReaderTest.initialize(test);
		} catch (IOException e) {
			e.printStackTrace();
		}
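If you want to double-check which index was assigned to which class, the initialized reader can list the labels it inferred from the folder names (an optional check, not in the original code; the folder names shown are illustrative):

		// Labels are sorted alphabetically, which fixes the index order
		System.out.println(recordReaderTrain.getLabels());// e.g. [abdomen, chest]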

Model building

I wanted to keep it a little simpler, but since we have come this far, I will borrow a network from the model zoo called SimpleCNN. The example shown here is not the stock SimpleCNN: the final output layer has been adjusted and added for this experiment. Nothing difficult was done; I copied from the SimpleCNN.java code and gave it an output layer for multi-class classification (with 2 classes a binary output would also work, but here I use softmax as an example). I will not explain it in detail, but the MultiLayerNetwork used here embodies DL4J's basic, simple CNN concept. The models familiar to anyone interested in deep learning are more complex and much larger, but such complex, huge models are constructed from the same building blocks.

		System.out.println("Start construct SimpleCNN model...");
		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().trainingWorkspaceMode(WorkspaceMode.ENABLED)
				.inferenceWorkspaceMode(WorkspaceMode.ENABLED).seed(seed).activation(Activation.IDENTITY)
				.weightInit(WeightInit.RELU).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.updater(new AdaDelta()).convolutionMode(ConvolutionMode.Same).list()
				// block 1
				.layer(0,
						new ConvolutionLayer.Builder(new int[] { 7, 7 }).name("image_array").nIn(inputShape[0]).nOut(16)
								.build())
				.layer(1, new BatchNormalization.Builder().build())
				.layer(2, new ConvolutionLayer.Builder(new int[] { 7, 7 }).nIn(16).nOut(16).build())
				.layer(3, new BatchNormalization.Builder().build())
				.layer(4, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(6, new DropoutLayer.Builder(0.5).build())

				// block 2
				.layer(7, new ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(32).build())
				.layer(8, new BatchNormalization.Builder().build())
				.layer(9, new ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(32).build())
				.layer(10, new BatchNormalization.Builder().build())
				.layer(11, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(12, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(13, new DropoutLayer.Builder(0.5).build())

				// block 3
				.layer(14, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(64).build())
				.layer(15, new BatchNormalization.Builder().build())
				.layer(16, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(64).build())
				.layer(17, new BatchNormalization.Builder().build())
				.layer(18, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(19, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(20, new DropoutLayer.Builder(0.5).build())

				// block 4
				.layer(21, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(128).build())
				.layer(22, new BatchNormalization.Builder().build())
				.layer(23, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(128).build())
				.layer(24, new BatchNormalization.Builder().build())
				.layer(25, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(26, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(27, new DropoutLayer.Builder(0.5).build())

				// block 5
				.layer(28, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(256).build())
				.layer(29, new BatchNormalization.Builder().build())
				.layer(30, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(256).build())
				.layer(31, new GlobalPoolingLayer.Builder(PoolingType.AVG).build())

				//output
				.layer(32, new OutputLayer.Builder().nIn(256).nOut(2)
						.lossFunction(LossFunctions.LossFunction.MCXENT)
						.weightInit(WeightInit.XAVIER)
						.activation(Activation.SOFTMAX)
						.build())
				.setInputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0]))
				.backpropType(BackpropType.Standard)
				.build();

		MultiLayerNetwork network = new MultiLayerNetwork(conf);
		network.init();
		System.out.println(network.summary());

Visualize the learning process

Use the features built into DL4J to watch how learning progresses at each epoch. Once the code in this article is assembled to the end and executed, training will start; at that point, open a web browser and go to http://localhost:9000 to follow the training progress graphically on your own PC.

		// visualize train process
		// URL:http://localhost:9000/train/overview
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);

You can also choose how to monitor the training process. Two commonly used listeners are StatsListener, which collects general-purpose information about the model, and ScoreIterationListener, which reports the model score (mainly the loss) at a specified interval.

		// set Stats Listener, to check confusion matrix for each epoch
		network.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

From image input to model input

It is time to train. We have already created the pipeline for reading image data; now we add the piece that converts it into model input: the DataSetIterator. A DataSetIterator is responsible for serving the data needed each time training is repeated. We create three of them: TRAIN, VAL (validation), and TEST. As you can see from the original data folder, the TEST images, unlike the others, are not assigned to class subfolders; they sit directly in the TEST folder. You could create subfolders and copy the files like the other data, but this is a good opportunity to also show how to read images without sorting them into folders.

		DataSetIterator traindataIter = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, numLabels);
		DataSetIterator valdataIter = new RecordReaderDataSetIterator(recordReaderVal, batchSize, 1, numLabels);
		//In this example, the test folder does not have a similar folder hierarchy, so leave the test data unlabeled.
		DataSetIterator testdataIter = new RecordReaderDataSetIterator(recordReaderTest, 1);// 1 is the batch size

Normalization

The final step before training is to normalize the data fed into the model. Normalization is a technique often used in statistics: it removes outliers and differences in scale between samples that could otherwise confuse the model during training. Here we set up a scaler that maps the pixel values of each image to numbers between 0 and 1. There are many kinds of scalers, and in general a scaler fitted on the training data is also applied to the validation and test data. (Strictly speaking, the 0-1 range conversion used here needs no fitting, but we follow the convention.)

		// Normalization
		DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
		scaler.fit(traindataIter);
		traindataIter.setPreProcessor(scaler);
		valdataIter.setPreProcessor(scaler);
		testdataIter.setPreProcessor(scaler);
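If you want to confirm that the scaler behaves as expected, an optional sanity check might look like the sketch below (it assumes the iterator can serve at least one batch):

		// Peek at one batch: pixel values should now lie in the 0-1 range
		DataSet sample = traindataIter.next();
		System.out.println("feature range: " + sample.getFeatures().minNumber()
				+ " .. " + sample.getFeatures().maxNumber());
		traindataIter.reset();// rewind so training still sees every batch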

Model training and validation

Repeat training the model for epochs iterations. All you have to do is pass the prepared DataSetIterator to the model's fit() method; it then automatically and repeatedly fetches batches, and the iterator is reset at the end of each epoch. (Depending on the model, the iteration within one epoch may need to be handled differently, so check the DL4J examples.)

Validation is performed after each training epoch. The familiar metrics and the confusion matrix are computed by network.evaluate(valdataIter);.

		System.out.println("Start training model....");
		int i = 0;
		while (i < epochs) {
			while (traindataIter.hasNext()) {
				DataSet trained = traindataIter.next();
//				System.out.println(trained.numExamples());//same as batch size
				network.fit(trained);
			}
			System.out.println("Evaluate model at iteration " + i + " ....");
			Evaluation eval = network.evaluate(valdataIter);//use nd4j's Evaluation
			System.out.println(eval.stats());
			valdataIter.reset();//Return Iterator to the beginning
			traindataIter.reset();//Return Iterator to the beginning
			i++;
		}

Model testing

Finally, let's test with data that was used for neither training nor validation. Here I show how to feed images in one by one and check the results yourself, without going through Evaluation.

		/*
		 * If the test images were organized into the same folder hierarchy
		 * as the training data, they could be evaluated as above.
		 * Even without organized folders,
		 * each image can be evaluated as follows.
		 */
		System.out.println("Test model....");
		while(testdataIter.hasNext()) {
			DataSet testData = testdataIter.next();
			System.out.println("testing... :"+testData.id());
			INDArray input = testData.getFeatures();
			INDArray pred = network.output(input);
			System.out.println(pred);
			int predLabel = Nd4j.argMax(pred).getInt(0);// index of the class with the highest probability
			if(predLabel == 0) {
				System.out.println("ABDOMEN"+" with probability "+pred.getDouble(predLabel));
			} else {
				System.out.println("CHEST"+" with probability "+pred.getDouble(predLabel));
			}
		}
		
		System.out.println("Finish....");

Run

The calculation process along the way can be visualized like this in the DL4J training UI.

[Screenshot: DL4J training UI overview page]

The evaluation along the way is as follows; epoch 16 is a pretty good result (output partially omitted):

	Evaluate model at iteration 15 ....

	# of classes: 2
	Accuracy:  0.9000
	Precision: 0.9167
	Recall:    0.9000
	F1 Score:  0.8889

The output of the final test is as follows:

	Test model....
	testing... :
	[[ 5.7758e-5, 0.9999]]
	CHEST with probability 0.9999421834945679
	testing... :
	[[ 0.5547, 0.4453]]
	ABDOMEN with probability 0.5546808838844299
	Finish....

The abdominal image was only barely judged correctly. It is still a suspicious model.

The complete code

It will be as follows.

ChestOrAbd.java



import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.WarpImageTransform;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ChestOrAbd {

	protected static final Logger log = LoggerFactory.getLogger(ChestOrAbd.class);

	public static void main(String[] args) {

		long seed = 42;
		final Random RAND_NUM_GEN = new Random(seed);
		final String[] ALLOWED_FORMATS = BaseImageLoader.ALLOWED_FORMATS;
		ParentPathLabelGenerator LABEL_GENERATOR_MAKER = new ParentPathLabelGenerator();
		BalancedPathFilter PATH_FILTER = new BalancedPathFilter(RAND_NUM_GEN, ALLOWED_FORMATS, LABEL_GENERATOR_MAKER);

		int numLabels = 2;// chest or abd
		int height = 64;// image size for train
		int width = 64;// image size for train
		int channels = 3;// image channels(in this case, image type is RGB, so 3 channels)
		int[] inputShape = new int[] {channels, height, width};
		int batchSize = 32;// number of samples per mini-batch (weights update after each batch)
		int epochs = 50;

		System.out.println("Preparing data....");
		// Prepare train
		File trainDir = new File("./Open_I_abd_vs_CXRs/TRAIN/");
		FileSplit trainSplit = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
		InputSplit train = trainSplit.sample(PATH_FILTER, 1.0)[0];
		// Prepare val
		File valDir = new File("./Open_I_abd_vs_CXRs/VAL/");
		FileSplit valSplit = new FileSplit(valDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
		InputSplit val = valSplit.sample(PATH_FILTER, 1.0)[0];
		// Prepare test
		File testDir = new File("./Open_I_abd_vs_CXRs/TEST/");
		FileSplit testSplit = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, RAND_NUM_GEN);
		InputSplit test = testSplit.sample(PATH_FILTER, 1.0)[0];
		
		System.out.println("train data total sample size " + train.length());
		System.out.println("validation total data sample size " + val.length());
		System.out.println("test data total sample size " + test.length());

		System.out.println("Prepare augumentation....");
		ImageTransform flipTransform1 = new FlipImageTransform(new Random(seed));
		ImageTransform flipTransform2 = new FlipImageTransform(new Random(seed));
		ImageTransform warpTransform = new WarpImageTransform(new Random(seed), inputShape[1]/10);
		boolean shuffle = false;
		List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(new Pair<>(flipTransform1, 0.9),
				new Pair<>(flipTransform2, 0.8), new Pair<>(warpTransform, 0.9));
		ImageTransform transform = new PipelineImageTransform(pipeline, shuffle);

		// data reader setup
		ImageRecordReader recordReaderTrain = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
		ImageRecordReader recordReaderVal = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
		/*
		 * The TEST folder keeps the structure of the original distribution,
		 * so labels cannot be derived from the folder hierarchy for the test data.
		 * (If you want labeled test data, arrange its folders like the others.)
		 */
//		ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels, LABEL_GENERATOR_MAKER);
		ImageRecordReader recordReaderTest = new ImageRecordReader(height, width, channels);
		try {
//			recordReaderTrain.initialize(train);// Train without transformations
			recordReaderTrain.initialize(train,transform);// Train with transformations
			recordReaderVal.initialize(val);//No augmentation on validation data
			recordReaderTest.initialize(test);
		} catch (IOException e) {
			e.printStackTrace();
		}

		System.out.println("Start construct SimpleCNN model...");
		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().trainingWorkspaceMode(WorkspaceMode.ENABLED)
				.inferenceWorkspaceMode(WorkspaceMode.ENABLED).seed(seed).activation(Activation.IDENTITY)
				.weightInit(WeightInit.RELU).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.updater(new AdaDelta()).convolutionMode(ConvolutionMode.Same).list()
				// block 1
				.layer(0,
						new ConvolutionLayer.Builder(new int[] { 7, 7 }).name("image_array").nIn(inputShape[0]).nOut(16)
								.build())
				.layer(1, new BatchNormalization.Builder().build())
				.layer(2, new ConvolutionLayer.Builder(new int[] { 7, 7 }).nIn(16).nOut(16).build())
				.layer(3, new BatchNormalization.Builder().build())
				.layer(4, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(6, new DropoutLayer.Builder(0.5).build())

				// block 2
				.layer(7, new ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(32).build())
				.layer(8, new BatchNormalization.Builder().build())
				.layer(9, new ConvolutionLayer.Builder(new int[] { 5, 5 }).nOut(32).build())
				.layer(10, new BatchNormalization.Builder().build())
				.layer(11, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(12, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(13, new DropoutLayer.Builder(0.5).build())

				// block 3
				.layer(14, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(64).build())
				.layer(15, new BatchNormalization.Builder().build())
				.layer(16, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(64).build())
				.layer(17, new BatchNormalization.Builder().build())
				.layer(18, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(19, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(20, new DropoutLayer.Builder(0.5).build())

				// block 4
				.layer(21, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(128).build())
				.layer(22, new BatchNormalization.Builder().build())
				.layer(23, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(128).build())
				.layer(24, new BatchNormalization.Builder().build())
				.layer(25, new ActivationLayer.Builder().activation(Activation.RELU).build())
				.layer(26, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[] { 2, 2 }).build())
				.layer(27, new DropoutLayer.Builder(0.5).build())

				// block 5
				.layer(28, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(256).build())
				.layer(29, new BatchNormalization.Builder().build())
				.layer(30, new ConvolutionLayer.Builder(new int[] { 3, 3 }).nOut(256).build())
				.layer(31, new GlobalPoolingLayer.Builder(PoolingType.AVG).build())

				//output
				.layer(32, new OutputLayer.Builder().nIn(256).nOut(2)
						.lossFunction(LossFunctions.LossFunction.MCXENT)
						.weightInit(WeightInit.XAVIER)
						.activation(Activation.SOFTMAX)
						.build())
				.setInputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0]))
				.backpropType(BackpropType.Standard)
				.build();

		MultiLayerNetwork network = new MultiLayerNetwork(conf);
		network.init();
		System.out.println(network.summary());
		
		// visualize train process
		// URL:http://localhost:9000/train/overview
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		
		// set Stats Listener, to check confusion matrix for each epoch
		network.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

		/*
		 * There are only two classes this time.
		 * The teacher label is attached to each image according to its folder.
		 * For example, image 1 (answer: abdomen) is labeled (chest: 0, abdomen: 1);
		 * a "1" marks the correct class.
		 * RecordReaderDataSetIterator takes four arguments:
		 * recordReaderTrain, batchSize, 1, numLabels.
		 * The "1" is the label index: the position of the label
		 * within each record produced by the reader.
		 */
		DataSetIterator traindataIter = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, numLabels);
		DataSetIterator valdataIter = new RecordReaderDataSetIterator(recordReaderVal, batchSize, 1, numLabels);
		//In this example, the test folder does not have a similar folder hierarchy, so leave the test data unlabeled.
		DataSetIterator testdataIter = new RecordReaderDataSetIterator(recordReaderTest, 1);// 1 is the batch size
		// Normalization
		DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
		scaler.fit(traindataIter);
		traindataIter.setPreProcessor(scaler);
		valdataIter.setPreProcessor(scaler);
		testdataIter.setPreProcessor(scaler);
		
		System.out.println("Start training model....");
		int i = 0;
		while (i < epochs) {
			while (traindataIter.hasNext()) {
				DataSet trained = traindataIter.next();
//				System.out.println(trained.numExamples());//same as batch size
				network.fit(trained);
			}
			System.out.println("Evaluate model at iteration " + i + " ....");
			Evaluation eval = network.evaluate(valdataIter);//use nd4j's Evaluation
			System.out.println(eval.stats());
			valdataIter.reset();//Return Iterator to the beginning
			traindataIter.reset();//Return Iterator to the beginning
			i++;
		}

		/*
		 * If the folder hierarchy of the test data were the same as the others,
		 * it could be evaluated as above.
		 * Even without organized folders,
		 * each image can be evaluated as follows.
		 */
		System.out.println("Test model....");
		while(testdataIter.hasNext()) {
			DataSet testData = testdataIter.next();
			System.out.println("testing... :"+testData.id());
			INDArray input = testData.getFeatures();
			INDArray pred = network.output(input);
			System.out.println(pred);
			int predLabel = Nd4j.argMax(pred).getInt(0);// index of the class with the highest probability
			if(predLabel == 0) {
				System.out.println("ABDOMEN"+" with probability "+pred.getDouble(predLabel));
			} else {
				System.out.println("CHEST"+" with probability "+pred.getDouble(predLabel));
			}
		}
		
		System.out.println("Finish....");
	}
}

Impressions

In my case, once I got this far I could start daydreaming about all sorts of variations: what about that approach, what if I did it this way. Possible next steps include transfer learning, incorporating layer types not used here, stepping up to complex models (ComputationGraph) (or, conversely, trial and error toward simplification), using RNNs and LSTMs, and tackling problems other than classification. I would like to keep up with the times and keep working through these topics.

Reference

Reference POM

pom.xml



<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.vis</groupId>
  <artifactId>ChestOrAbd</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  
  	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<java.version>1.8</java.version>
		<nd4j.version>1.0.0-beta4</nd4j.version>
		<dl4j.version>1.0.0-beta4</dl4j.version>
		<datavec.version>1.0.0-beta4</datavec.version>
		<arbiter.version>1.0.0-beta4</arbiter.version>
		<logback.version>1.2.3</logback.version>
		<dl4j.spark.version>1.0.0-beta4_spark_2</dl4j.spark.version>
	</properties>

	<dependencies>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${nd4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-cuda-10.0-platform</artifactId>
			<version>${nd4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>dl4j-spark_2.11</artifactId>
			<version>${dl4j.spark.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-nlp</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-zoo</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-deeplearning4j</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-ui_2.11</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.datavec</groupId>
			<artifactId>datavec-data-codec</artifactId>
			<version>${datavec.version}</version>
		</dependency>
		<dependency>
			<groupId>org.apache.httpcomponents</groupId>
			<artifactId>httpclient</artifactId>
			<version>4.3.5</version>
		</dependency>
		<dependency>
			<groupId>ch.qos.logback</groupId>
			<artifactId>logback-classic</artifactId>
			<version>${logback.version}</version>
		</dependency>
		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-annotations</artifactId>
			<version>2.11.0</version>
		</dependency>
	</dependencies>
</project>
