[JAVA] LSTM unterstützt eine Vielzahl von RNA-Seq-Funktionen

Warum hast du einen Artikel geschrieben?

Überraschenderweise gab es eine Person, die meinen unorganisierten Artikel las, also gibt es nur eine Erde → alle menschlichen Brüder → ein Gefühl der Gemeinschaft ist von selbst entstanden, und ich schreibe gelegentlich. (Außerdem ist es gut für mich.)

Was ich versucht habe

Was wurde verwendet

Umgebung

--ubuntu 18.04 (Da es sich um Java handelt, spielt das Betriebssystem keine Rolle im Detail) --maven + dl4j verwandt

Daten

Diese Daten sind Teil der Daten, die zufällig aus dem Datensatz extrahiert wurden, der vom Pan-Cancer-Analyseprojekt des Krebsgenomatlas unter Verwendung eines Hochleistungs-RNA-Analysators namens HiSeq of Illumina erfasst wurde. Weitere Informationen finden Sie unter https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/.


RNA-Seq ist eine nebeneinander angeordnete Anordnung, die quantifiziert, wie viel ein Gen in Bezug auf die Referenzsequenz für jedes Gen wahrscheinlich exprimiert wird. Die Referenzsequenz ist ein Genmuster, das der Forscher entscheidet, wenn das Genmuster bis zu einem gewissen Grad bekannt ist. Das Sequenzmuster zieht die Aufmerksamkeit auf sich, weil es die Eigenschaften und Attribute von Individuen erklären kann. Insbesondere die Suche nach RNA in Krebszellen wurde in der Arzneimittelforschung seit langem durchgeführt. Laut einer mit der Angelegenheit vertrauten Person gibt es noch viel zu tun, selbst wenn das menschliche Genom aufgeklärt ist.


Was zu klassifizieren

Die folgenden Krebsarten werden in dieser Studie untersucht, und die Probendaten umfassen fünf (★) davon.

Diese fünf Krebsarten werden klassifiziert.

Projektvorbereitung

Erstellen Sie ein Maven-Projekt und eine beliebige Klassendatei.

CSV-Vorbereitung

//Laden von csv
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv"); 
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv"); 

Datenextraktion

int numClasses = 5;     //5 classes
int batchSize = 801;    //samples total

//Holen Sie sich zuerst den Inhalt aus dem Datenkörper
RecordReader reader = new CSVRecordReader(1,',');//skip header
try {
	reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
	e.printStackTrace();
}

double[][] dataObj = new double[batchSize][];
int itr = 0;
while(reader.hasNext()) {
	List<Writable> row = reader.next();
	double scalers[] = new double[row.size()-1];
	for(int i = 0; i < row.size()-1; i++) {
		if(i == 0) {//skip subject
			continue;
		}
		double scaler = Double.parseDouble(new ConvertToString().map(row.get(i)).toString());
		scalers[i] = scaler;
	}
	dataObj[itr] = scalers;
	itr++;
}
System.out.println("Data samples "+ +dataObj.length);//801

//Etikett lesen
//Konvertieren Sie auch für Multi-Label
//label
try {
	reader = new CSVRecordReader(1,',');//skip header
	reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
	e.printStackTrace();
}
double[][] labels = new double[batchSize][];
itr = 0;
while(reader.hasNext()) {
	List<Writable> row = reader.next();
	double scalers[] = null;
	for(int i = 0; i < row.size(); i++) {
		if(i == 0) {//skip subject
			continue;
		}
		// Class
		if(i == 1) {
			String classname = new ConvertToString().map(row.get(i)).toString();
			switch(classname) {
				case "BRCA":
					scalers = new double[]{1,0,0,0,0};
					break;
				case "PRAD":
					scalers = new double[]{0,1,0,0,0};
					break;
				case "LUAD":
					scalers = new double[]{0,0,1,0,0};
					break;
				case "KIRC":
					scalers = new double[]{0,0,0,1,0};
					break;
				case "COAD":
					scalers = new double[]{0,0,0,0,1};
					break;
				default:
					break;
			}
			labels[itr] = scalers;
			itr++;
		}
	}
}
System.out.println("LABEL : "+labels.length);//801

Einmal in INDArray, um die Daten in ein DataSet-Objekt umzuwandeln

//Erstellen Sie ein DataSet
INDArray dataArray = Nd4j.create(dataObj,'c');
System.out.println(dataArray.shapeInfoToString());
INDArray labelArray = Nd4j.create(labels,'c');
System.out.println(labelArray.shapeInfoToString());

//Rank: 2,Offset: 0
// Order: c Shape: [801,20531],  stride: [20531,1]
//Rank: 2,Offset: 0
// Order: c Shape: [801,5],  stride: [5,1]

Zum DataSet

DataSet dataset = new DataSet(dataArray, labelArray);
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L));//600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();
System.out.println(train.labelCounts());
System.out.println(test.labelCounts());

//{0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
//{0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}

Modellbau / Schulung / Bewertung

//MODEL TRAIN AND EVALUATION
int numInput = 20531;
int numOutput = numClasses;
int hiddenNode = 500;//Machtlos
int numEpochs = 50;
		
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
				.seed(123)
				.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.weightInit(WeightInit.XAVIER)
				.updater(new Adam(0.001))
				.list()
				.layer(0,new LSTM.Builder()
							.nIn(numInput)
							.nOut(hiddenNode)
							.activation(Activation.RELU)
							.build())
				.layer(1,new LSTM.Builder()
						.nIn(hiddenNode)
						.nOut(hiddenNode)
						.activation(Activation.RELU)
						.build())
				.layer(2,new LSTM.Builder()
						.nIn(hiddenNode)
						.nOut(hiddenNode)
						.activation(Activation.RELU)
						.build())
				.layer(3,new RnnOutputLayer.Builder()
						.nIn(hiddenNode)
						.nOut(numOutput)
						.activation(Activation.SOFTMAX)
						.lossFunction(LossFunction.MCXENT)//multi class cross entropy
						.build())
				.pretrain(false)
				.backprop(true)
				.build();
		
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for(int i=0;i<numEpochs;i++) {
	model.fit(train);
}
		
System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(5);
for(DataSet row :test.asList()) {
	INDArray testdata = row.getFeatures();
	INDArray pred = model.output(testdata);
	eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());

Ausgabe des Bewertungsergebnisses

TRAIN START...
EVALUATION START...

Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times


==========================Scores========================================
 # of classes:    5
 Accuracy:        0.9602
 Precision:       0.9671
 Recall:          0.9284
 F1 Score:        0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
========================================================================

Impressionen

Wenn die Merkmalsmenge mehrere Hundert überschreitet, werden logistische Regression und SVM (linear) verwendet, aber ich dachte, dass es möglicherweise besser ist, LSTM zu verwenden. Das Studium dauert nicht so lange wie erwartet. Ich habe es gerade versucht, aber ich bin froh, dass ich es versucht habe.

Referenzmaterial

Anhang (pom.xml)

pom.xml


<project xmlns="http://maven.apache.org/POM/4.0.0"
	xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<groupId>com.vis</groupId>
	<artifactId>CancerGenomeTest</artifactId>
	<version>0.0.1-SNAPSHOT</version>

	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<java.version>1.8</java.version>
		<nd4j.version>1.0.0-alpha</nd4j.version>
		<dl4j.version>1.0.0-alpha</dl4j.version>
		<datavec.version>1.0.0-alpha</datavec.version>
		<arbiter.version>1.0.0-alpha</arbiter.version>
		<logback.version>1.2.3</logback.version>
		<dl4j.spark.version>1.0.0-alpha_spark_2</dl4j.spark.version>
	</properties>

	<dependencies>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${nd4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>dl4j-spark_2.11</artifactId>
			<version>${dl4j.spark.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-nlp</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-zoo</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-deeplearning4j</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-ui_2.11</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.datavec</groupId>
			<artifactId>datavec-data-codec</artifactId>
			<version>${datavec.version}</version>
		</dependency>
		<dependency>
			<groupId>org.apache.httpcomponents</groupId>
			<artifactId>httpclient</artifactId>
			<version>4.3.5</version>
		</dependency>
		<dependency>
			<groupId>ch.qos.logback</groupId>
			<artifactId>logback-classic</artifactId>
			<version>${logback.version}</version>
		</dependency>
		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-annotations</artifactId>
			<version>2.11.0</version>
		</dependency>
	</dependencies>
</project>

Recommended Posts

LSTM unterstützt eine Vielzahl von RNA-Seq-Funktionen
Zusammenfassung der einfachen Funktionen von Bootstrap für Anfänger