[JAVA] LSTM prend en charge une grande quantité de fonctionnalités RNA-Seq

Pourquoi avez-vous écrit un article

Étonnamment, il y avait une personne qui a lu mon article désorganisé, donc il n'y a qu'une seule terre → tous les êtres humains frères → un sentiment de collègue bourgeoise tout seul, et j'écris de temps en temps. (Aussi, c'est bon pour moi.)

Ce que j'ai essayé

--Utilisez LSTM pour les couches de modèle

Ce qui a été utilisé

environnement

--ubuntu 18.04 (Puisqu'il s'agit de java, le système d'exploitation n'a pas d'importance en détail) --maven + dl4j liés

Les données

Ces données font partie des données extraites au hasard de l'ensemble de données acquises par le projet d'analyse pan-cancer de l'atlas du génome du cancer à l'aide d'un analyseur d'ARN haute performance appelé HiSeq d'Illumina. Pour plus d'informations, https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/


RNA-Seq est un arrangement côte à côte qui quantifie la quantité d'expressions d'un gène par rapport à la séquence de référence pour chaque gène. La séquence de référence est un modèle de gène que le chercheur décide lorsque le modèle de gène est connu dans une certaine mesure. Le modèle de séquence attire l'attention car il peut expliquer les propriétés et les attributs des individus. En particulier, la recherche d'ARN dans les cellules cancéreuses est menée depuis longtemps dans la recherche de découverte de médicaments. Selon une personne familiarisée avec le sujet, même si le génome humain est élucidé, il reste encore beaucoup de travail à faire.


Que classer

Les types de cancer suivants sont examinés dans cette étude, et l'échantillon de données comprend cinq (★) d'entre eux.

Ces cinq types de cancer sont classés.

Préparation du projet

Créez un projet maven et créez un fichier de classe arbitraire.

Préparation CSV

//Chargement de csv
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv"); 
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv"); 

Extraction de données

int numClasses = 5;     //5 classes
int batchSize = 801;    //samples total

//Obtenez d'abord le contenu du corps de données
RecordReader reader = new CSVRecordReader(1,',');//skip header
try {
	reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
	e.printStackTrace();
}

double[][] dataObj = new double[batchSize][];
int itr = 0;
while(reader.hasNext()) {
	List<Writable> row = reader.next();
	double scalers[] = new double[row.size()-1];
	for(int i = 0; i < row.size()-1; i++) {
		if(i == 0) {//skip subject
			continue;
		}
		double scaler = Double.parseDouble(new ConvertToString().map(row.get(i)).toString());
		scalers[i] = scaler;
	}
	dataObj[itr] = scalers;
	itr++;
}
System.out.println("Data samples "+ +dataObj.length);//801

//Lire l'étiquette
//Convertissez également pour multi-étiquettes
//label
try {
	reader = new CSVRecordReader(1,',');//skip header
	reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
	e.printStackTrace();
}
double[][] labels = new double[batchSize][];
itr = 0;
while(reader.hasNext()) {
	List<Writable> row = reader.next();
	double scalers[] = null;
	for(int i = 0; i < row.size(); i++) {
		if(i == 0) {//skip subject
			continue;
		}
		// Class
		if(i == 1) {
			String classname = new ConvertToString().map(row.get(i)).toString();
			switch(classname) {
				case "BRCA":
					scalers = new double[]{1,0,0,0,0};
					break;
				case "PRAD":
					scalers = new double[]{0,1,0,0,0};
					break;
				case "LUAD":
					scalers = new double[]{0,0,1,0,0};
					break;
				case "KIRC":
					scalers = new double[]{0,0,0,1,0};
					break;
				case "COAD":
					scalers = new double[]{0,0,0,0,1};
					break;
				default:
					break;
			}
			labels[itr] = scalers;
			itr++;
		}
	}
}
System.out.println("LABEL : "+labels.length);//801

Une fois dans INDArray pour transformer les données en un objet DataSet

//Créer un DataSet
INDArray dataArray = Nd4j.create(dataObj,'c');
System.out.println(dataArray.shapeInfoToString());
INDArray labelArray = Nd4j.create(labels,'c');
System.out.println(labelArray.shapeInfoToString());

//Rank: 2,Offset: 0
// Order: c Shape: [801,20531],  stride: [20531,1]
//Rank: 2,Offset: 0
// Order: c Shape: [801,5],  stride: [5,1]

Vers DataSet

DataSet dataset = new DataSet(dataArray, labelArray);
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L));//600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();
System.out.println(train.labelCounts());
System.out.println(test.labelCounts());

//{0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
//{0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}

Construction / formation / évaluation du modèle

//MODEL TRAIN AND EVALUATION
int numInput = 20531;
int numOutput = numClasses;
int hiddenNode = 500;//Impuissant
int numEpochs = 50;
		
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
				.seed(123)
				.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.weightInit(WeightInit.XAVIER)
				.updater(new Adam(0.001))
				.list()
				.layer(0,new LSTM.Builder()
							.nIn(numInput)
							.nOut(hiddenNode)
							.activation(Activation.RELU)
							.build())
				.layer(1,new LSTM.Builder()
						.nIn(hiddenNode)
						.nOut(hiddenNode)
						.activation(Activation.RELU)
						.build())
				.layer(2,new LSTM.Builder()
						.nIn(hiddenNode)
						.nOut(hiddenNode)
						.activation(Activation.RELU)
						.build())
				.layer(3,new RnnOutputLayer.Builder()
						.nIn(hiddenNode)
						.nOut(numOutput)
						.activation(Activation.SOFTMAX)
						.lossFunction(LossFunction.MCXENT)//multi class cross entropy
						.build())
				.pretrain(false)
				.backprop(true)
				.build();
		
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for(int i=0;i<numEpochs;i++) {
	model.fit(train);
}
		
System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(5);
for(DataSet row :test.asList()) {
	INDArray testdata = row.getFeatures();
	INDArray pred = model.output(testdata);
	eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());

Sortie du résultat de l'évaluation

TRAIN START...
EVALUATION START...

Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times


==========================Scores========================================
 # of classes:    5
 Accuracy:        0.9602
 Precision:       0.9671
 Recall:          0.9284
 F1 Score:        0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
========================================================================

Impressions

Lorsque la quantité de caractéristiques dépasse plusieurs centaines, la régression logistique et SVM (linéaire) sont utilisées, mais j'ai pensé qu'il serait peut-être préférable d'utiliser LSTM. L'étude ne prend pas aussi longtemps que prévu. Je viens de l'essayer, mais je suis content de l'avoir essayé.

Matériel de référence

Pièce jointe (pom.xml)

pom.xml


<project xmlns="http://maven.apache.org/POM/4.0.0"
	xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<groupId>com.vis</groupId>
	<artifactId>CancerGenomeTest</artifactId>
	<version>0.0.1-SNAPSHOT</version>

	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<java.version>1.8</java.version>
		<nd4j.version>1.0.0-alpha</nd4j.version>
		<dl4j.version>1.0.0-alpha</dl4j.version>
		<datavec.version>1.0.0-alpha</datavec.version>
		<arbiter.version>1.0.0-alpha</arbiter.version>
		<logback.version>1.2.3</logback.version>
		<dl4j.spark.version>1.0.0-alpha_spark_2</dl4j.spark.version>
	</properties>

	<dependencies>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${nd4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>dl4j-spark_2.11</artifactId>
			<version>${dl4j.spark.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-nlp</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-zoo</artifactId>
			<version>${dl4j.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-deeplearning4j</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>arbiter-ui_2.11</artifactId>
			<version>${arbiter.version}</version>
		</dependency>
		<dependency>
			<groupId>org.datavec</groupId>
			<artifactId>datavec-data-codec</artifactId>
			<version>${datavec.version}</version>
		</dependency>
		<dependency>
			<groupId>org.apache.httpcomponents</groupId>
			<artifactId>httpclient</artifactId>
			<version>4.3.5</version>
		</dependency>
		<dependency>
			<groupId>ch.qos.logback</groupId>
			<artifactId>logback-classic</artifactId>
			<version>${logback.version}</version>
		</dependency>
		<dependency>
			<groupId>com.fasterxml.jackson.core</groupId>
			<artifactId>jackson-annotations</artifactId>
			<version>2.11.0</version>
		</dependency>
	</dependencies>
</project>

Recommended Posts

LSTM prend en charge une grande quantité de fonctionnalités RNA-Seq
Résumé des fonctionnalités simples de Bootstrap pour les débutants