[JAVA] Handling a huge number of RNA-Seq features with an LSTM

Why I wrote this article

Surprisingly, someone actually read my disorganized article. There is only one Earth → all human beings are brothers → a sense of fellowship sprang up in me on its own, so I write occasionally. (It's also good for me.)

What I tried

--Use LSTM for the model layers
--Predict the cancer type (5 classes) by multi-class classification using RNA-Seq data (801 x 20531)

What I used

--Laptop (an ordinary one)
--Optional: NVIDIA GPU (a 1050Ti this time) (AMD GPUs won't work, because dl4j's GPU support relies on CUDA)

Environment

--ubuntu 18.04 (since this is Java, the OS hardly matters)
--maven + dl4j-related dependencies

Data

This data is a random subset of the dataset collected by The Cancer Genome Atlas (TCGA) Pan-Cancer analysis project using HiSeq, a high-throughput sequencer from Illumina. For more information, see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/


RNA-Seq data lists, side by side, a quantified expression level for each gene, measured against a reference sequence. The reference sequence is a gene pattern that researchers settle on once the pattern is known to some extent. Sequence patterns attract attention because they may explain the properties and attributes of individuals. In particular, RNA in cancer cells has long been studied in drug discovery research. According to someone familiar with the field, there is still a lot of work left to do even though the human genome has been decoded.


What to classify

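The pan-cancer study covers many cancer types, and the sample data includes five of them. They appear as the class names in labels.csv:

--BRCA (breast invasive carcinoma)
--PRAD (prostate adenocarcinoma)
--LUAD (lung adenocarcinoma)
--KIRC (kidney renal clear cell carcinoma)
--COAD (colon adenocarcinoma)

These five types of cancer are classified.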

Project preparation

Create a maven project and add a class file with any name. The pom.xml used here is attached at the end of the article.

CSV preparation

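import java.io.File;//goes at the top of the class file

//Load the CSV files (paths are relative to the working directory)
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv");
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv");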

Data extraction

//imports for this part (they belong at the top of the class file)
import java.io.IOException;
import java.util.List;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

int numClasses = 5;     //5 classes
int batchSize = 801;    //total number of samples

//Read the feature matrix first
RecordReader reader = new CSVRecordReader(1,',');//skip the header line
try {
	reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
	e.printStackTrace();
}

double[][] dataObj = new double[batchSize][];
int itr = 0;
while(reader.hasNext()) {
	List<Writable> row = reader.next();
	//column 0 is the sample ID; the remaining row.size()-1 columns are expression values
	double[] scalars = new double[row.size()-1];
	for(int i = 1; i < row.size(); i++) {
		//shift by one so that no slot stays empty and the last gene is not dropped
		scalars[i-1] = Double.parseDouble(row.get(i).toString());
	}
	dataObj[itr] = scalars;
	itr++;
}
System.out.println("Data samples "+dataObj.length);//801

//Read the labels
//and one-hot encode them for multi-class classification
try {
	reader = new CSVRecordReader(1,',');//skip the header line
	reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
	e.printStackTrace();
}
double[][] labels = new double[batchSize][];
itr = 0;
while(reader.hasNext()) {
	List<Writable> row = reader.next();
	//column 0 is the sample ID, column 1 is the class name
	String classname = row.get(1).toString();
	double[] oneHot;
	switch(classname) {
		case "BRCA":
			oneHot = new double[]{1,0,0,0,0};
			break;
		case "PRAD":
			oneHot = new double[]{0,1,0,0,0};
			break;
		case "LUAD":
			oneHot = new double[]{0,0,1,0,0};
			break;
		case "KIRC":
			oneHot = new double[]{0,0,0,1,0};
			break;
		case "COAD":
			oneHot = new double[]{0,0,0,0,1};
			break;
		default:
			//fail early instead of leaving a null row in the label matrix
			throw new IllegalStateException("Unknown class: " + classname);
	}
	labels[itr] = oneHot;
	itr++;
}
System.out.println("LABEL : "+labels.length);//801
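
The row parsing and one-hot encoding do not depend on dl4j, so the indexing can be checked in isolation. The following is a minimal sketch in plain Java, assuming each row is laid out as a sample ID followed by expression values. The class `RnaSeqParsing`, its method names, and the tiny in-memory row are hypothetical and only illustrate the idea.

```java
import java.util.Arrays;
import java.util.List;

public class RnaSeqParsing {
    // Class order for the one-hot vectors (same order as the switch statement in the article)
    static final List<String> CLASSES = Arrays.asList("BRCA", "PRAD", "LUAD", "KIRC", "COAD");

    // Turns one CSV row (sample ID, then expression values) into a feature vector.
    static double[] parseFeatures(String csvRow) {
        String[] cols = csvRow.split(",");
        double[] features = new double[cols.length - 1];
        for (int i = 1; i < cols.length; i++) {
            // column i of the row becomes feature i-1, so no slot stays empty
            features[i - 1] = Double.parseDouble(cols[i].trim());
        }
        return features;
    }

    // One-hot encodes a class name; unknown names are rejected.
    static double[] oneHot(String className) {
        int idx = CLASSES.indexOf(className);
        if (idx < 0) {
            throw new IllegalArgumentException("Unknown class: " + className);
        }
        double[] v = new double[CLASSES.size()];
        v[idx] = 1.0;
        return v;
    }

    public static void main(String[] args) {
        // hypothetical row with three expression values
        System.out.println(Arrays.toString(parseFeatures("sample_0,0.0,2.5,3.1"))); // [0.0, 2.5, 3.1]
        System.out.println(Arrays.toString(oneHot("LUAD")));                        // [0.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```

Looking the class up in a list keeps the class order in one place, which makes it harder for the label encoding and the evaluation output to drift apart.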

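Convert to INDArray first, so the data can be turned into a DataSet object

//imports for this part (they belong at the top of the class file)
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

//Create the INDArrays that will go into the DataSet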
INDArray dataArray = Nd4j.create(dataObj,'c');
System.out.println(dataArray.shapeInfoToString());
INDArray labelArray = Nd4j.create(labels,'c');
System.out.println(labelArray.shapeInfoToString());

//Rank: 2,Offset: 0
// Order: c Shape: [801,20531],  stride: [20531,1]
//Rank: 2,Offset: 0
// Order: c Shape: [801,5],  stride: [5,1]
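
The printed shape and stride describe a row-major ('c' order) layout: element (i, j) of the [801,20531] array sits at linear offset i*20531 + j*1. A small sketch of that arithmetic in plain Java, without any nd4j dependency; `StrideDemo` and `offset` are illustrative names, not part of the nd4j API.

```java
public class StrideDemo {
    // Linear offset of element (row, col), given the per-dimension strides.
    static long offset(long row, long col, long[] stride) {
        return row * stride[0] + col * stride[1];
    }

    public static void main(String[] args) {
        long[] shape = {801, 20531};
        // row-major: moving one row skips a full row of columns, moving one column skips 1
        long[] stride = {shape[1], 1};

        System.out.println(offset(0, 0, stride));        // 0
        System.out.println(offset(1, 0, stride));        // 20531
        System.out.println(offset(800, 20530, stride));  // 16445330, i.e. 801*20531 - 1
    }
}
```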

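Convert to a DataSet and split it

//imports for this part (they belong at the top of the class file)
import java.util.Random;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;

DataSet dataset = new DataSet(dataArray, labelArray);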
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L));//600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();
System.out.println(train.labelCounts());
System.out.println(test.labelCounts());

//{0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
//{0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}
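
Conceptually, a seeded hold-out split like this one shuffles the samples with a fixed seed and then cuts the shuffled list in two. The sketch below shows that idea on plain indices. It is an assumption-laden illustration, not dl4j's implementation; `SplitDemo` and `split` are hypothetical names, and the resulting index order will not match what dl4j produces for the same seed.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SplitDemo {
    // Shuffles the indices 0..total-1 with a fixed seed and returns {train, test}.
    static List<List<Integer>> split(int total, int numTrain, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < total; i++) {
            idx.add(i);
        }
        Collections.shuffle(idx, new Random(seed)); // same seed, same order
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(idx.subList(0, numTrain)));     // training indices
        parts.add(new ArrayList<>(idx.subList(numTrain, total))); // held-out indices
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(801, 600, 42L);
        System.out.println("train: " + parts.get(0).size()); // train: 600
        System.out.println("test:  " + parts.get(1).size()); // test:  201
    }
}
```

Because the split is not stratified, the class proportions in the two parts can differ, which is why it is worth printing the label counts as the article does.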

Model construction / training / evaluation

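//imports for this part (they belong at the top of the class file)
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

//MODEL TRAIN AND EVALUATION
int numInput = 20531;//number of gene features per sample
int numOutput = numClasses;
int hiddenNode = 500;//hidden units per layer (kept modest because the machine is not powerful)
int numEpochs = 50;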
		
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
				.seed(123)
				.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.weightInit(WeightInit.XAVIER)
				.updater(new Adam(0.001))
				.list()
				.layer(0,new LSTM.Builder()
							.nIn(numInput)
							.nOut(hiddenNode)
							.activation(Activation.RELU)
							.build())
				.layer(1,new LSTM.Builder()
						.nIn(hiddenNode)
						.nOut(hiddenNode)
						.activation(Activation.RELU)
						.build())
				.layer(2,new LSTM.Builder()
						.nIn(hiddenNode)
						.nOut(hiddenNode)
						.activation(Activation.RELU)
						.build())
				.layer(3,new RnnOutputLayer.Builder()
						.nIn(hiddenNode)
						.nOut(numOutput)
						.activation(Activation.SOFTMAX)
						.lossFunction(LossFunction.MCXENT)//multi class cross entropy
						.build())
				.pretrain(false)
				.backprop(true)
				.build();
		
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for(int i=0;i<numEpochs;i++) {
	model.fit(train);
}
		
System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(numClasses);
for(DataSet row :test.asList()) {
	INDArray testdata = row.getFeatures();
	INDArray pred = model.output(testdata);
	eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());
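
The output layer pairs a softmax activation with the MCXENT (multi-class cross entropy) loss. As a rough illustration of what those two compute for a single sample, here is a sketch in plain Java. It is not dl4j's implementation; `SoftmaxXent` and its method names are made up for this example.

```java
public class SoftmaxXent {
    // Converts raw scores into probabilities that sum to 1.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sum = 0.0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max); // subtracting the max avoids overflow
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    // Cross entropy between a one-hot label and predicted probabilities.
    static double crossEntropy(double[] oneHot, double[] probs) {
        double loss = 0.0;
        for (int i = 0; i < oneHot.length; i++) {
            if (oneHot[i] > 0) {
                loss -= oneHot[i] * Math.log(probs[i]);
            }
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[]{2.0, 0.5, 0.1, 0.1, 0.1}); // hypothetical scores for 5 classes
        double[] label = {1, 0, 0, 0, 0};                                // true class: index 0
        System.out.println("loss = " + crossEntropy(label, probs));
    }
}
```

With a one-hot label the loss reduces to the negative log of the probability assigned to the true class, so a model that guesses uniformly over five classes gets a loss of ln 5.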

Evaluation output

TRAIN START...
EVALUATION START...

Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times


==========================Scores========================================
 # of classes:    5
 Accuracy:        0.9602
 Precision:       0.9671
 Recall:          0.9284
 F1 Score:        0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
========================================================================
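
The scores can be reproduced by hand from the per-class counts printed above. The sketch below arranges those counts as a confusion matrix (rows are the actual class, columns the predicted class) and computes accuracy plus macro-averaged precision, recall, and F1 in plain Java. `MacroMetrics` is a hypothetical helper, and it assumes every class occurs and is predicted at least once, so no division by zero is handled.

```java
import java.util.Locale;

public class MacroMetrics {
    // rows = actual class, columns = predicted class (counts from the evaluation output)
    static final int[][] CM = {
        {80,  0,  0,  0,  0},
        { 0, 31,  0,  0,  0},
        { 3,  0, 34,  0,  0},
        { 0,  0,  1, 36,  0},
        { 0,  0,  4,  0, 12}
    };

    // Returns {accuracy, macro precision, macro recall, macro F1}.
    static double[] scores(int[][] cm) {
        int n = cm.length;
        double total = 0, correct = 0, pSum = 0, rSum = 0, fSum = 0;
        for (int k = 0; k < n; k++) {
            double tp = cm[k][k], actual = 0, predicted = 0;
            for (int j = 0; j < n; j++) {
                actual += cm[k][j];    // row sum: samples whose true class is k
                predicted += cm[j][k]; // column sum: samples predicted as k
            }
            double precision = tp / predicted;
            double recall = tp / actual;
            pSum += precision;
            rSum += recall;
            fSum += 2 * precision * recall / (precision + recall); // per-class F1
            total += actual;
            correct += tp;
        }
        return new double[]{correct / total, pSum / n, rSum / n, fSum / n};
    }

    public static void main(String[] args) {
        double[] s = scores(CM);
        System.out.println(String.format(Locale.ROOT,
                "Accuracy %.4f Precision %.4f Recall %.4f F1 %.4f", s[0], s[1], s[2], s[3]));
        // prints: Accuracy 0.9602 Precision 0.9671 Recall 0.9284 F1 0.9440
    }
}
```

The macro F1 here is the plain average of the five per-class F1 values, which matches the reported 0.9440.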

Impressions

Logistic regression or a (linear) SVM is the usual choice once the number of features exceeds several hundred, but I thought it might be worth using an LSTM instead. Training did not take as long as I expected. It was only an experiment, but I'm glad I tried it.

Attachment (pom.xml)


<project xmlns="http://maven.apache.org/POM/4.0.0"
	xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<groupId>com.vis</groupId>
	<artifactId>CancerGenomeTest</artifactId>
	<version>0.0.1-SNAPSHOT</version>

	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<java.version>1.8</java.version>
		<nd4j.version>1.0.0-alpha</nd4j.version>
		<dl4j.version>1.0.0-alpha</dl4j.version>
		<datavec.version>1.0.0-alpha</datavec.version>
		<arbiter.version>1.0.0-alpha</arbiter.version>
		<logback.version>1.2.3</logback.version>
		<dl4j.spark.version>1.0.0-alpha_spark_2</dl4j.spark.version>
	</properties>

	<dependencies>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${nd4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>dl4j-spark_2.11</artifactId>
			<version>${dl4j.spark.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-nlp</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-zoo</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-deeplearning4j</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-ui_2.11</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.datavec</groupId>
			<artifactId>datavec-data-codec</artifactId>
			<version>${datavec.version}</version>
		</dependency>
		<dependency>
			<groupId>org.apache.httpcomponents</groupId>
			<artifactId>httpclient</artifactId>
			<version>4.3.5</version>
		</dependency>
		<dependency>
			<groupId>ch.qos.logback</groupId>
			<artifactId>logback-classic</artifactId>
			<version>${logback.version}</version>
		</dependency>
		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-annotations</artifactId>
			<version>2.11.0</version>
		</dependency>
	</dependencies>
</project>
