Étonnamment, il y avait une personne qui a lu mon article désorganisé, donc il n'y a qu'une seule terre → tous les êtres humains frères → un sentiment de collègue bourgeoise tout seul, et j'écris de temps en temps. (Aussi, c'est bon pour moi.)
--Utilisez LSTM pour les couches de modèle
--ubuntu 18.04 (Puisqu'il s'agit de java, le système d'exploitation n'a pas d'importance en détail) --maven + dl4j liés
Ces données font partie des données extraites au hasard de l'ensemble de données acquises par le projet d'analyse pan-cancer de l'atlas du génome du cancer à l'aide d'un analyseur d'ARN haute performance appelé HiSeq d'Illumina. Pour plus d'informations, https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/
RNA-Seq est un arrangement côte à côte qui quantifie la quantité d'expressions d'un gène par rapport à la séquence de référence pour chaque gène. La séquence de référence est un modèle de gène que le chercheur décide lorsque le modèle de gène est connu dans une certaine mesure. Le modèle de séquence attire l'attention car il peut expliquer les propriétés et les attributs des individus. En particulier, la recherche d'ARN dans les cellules cancéreuses est menée depuis longtemps dans la recherche de découverte de médicaments. Selon une personne familiarisée avec le sujet, même si le génome humain est élucidé, il reste encore beaucoup de travail à faire.
Les types de cancer suivants sont examinés dans cette étude, et l'échantillon de données comprend cinq (★) d'entre eux.
Ces cinq types de cancer sont classés.
Créez un projet maven et créez un fichier de classe arbitraire.
//Chargement de csv
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv");
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv");
int numClasses = 5; //5 classes
int batchSize = 801; //samples total
//Obtenez d'abord le contenu du corps de données
RecordReader reader = new CSVRecordReader(1,',');//skip header
try {
reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
e.printStackTrace();
}
double[][] dataObj = new double[batchSize][];
int itr = 0;
while(reader.hasNext()) {
List<Writable> row = reader.next();
double scalers[] = new double[row.size()-1];
for(int i = 0; i < row.size()-1; i++) {
if(i == 0) {//skip subject
continue;
}
double scaler = Double.parseDouble(new ConvertToString().map(row.get(i)).toString());
scalers[i] = scaler;
}
dataObj[itr] = scalers;
itr++;
}
System.out.println("Data samples "+ +dataObj.length);//801
//Lire l'étiquette
//Convertissez également pour multi-étiquettes
//label
try {
reader = new CSVRecordReader(1,',');//skip header
reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
e.printStackTrace();
}
double[][] labels = new double[batchSize][];
itr = 0;
while(reader.hasNext()) {
List<Writable> row = reader.next();
double scalers[] = null;
for(int i = 0; i < row.size(); i++) {
if(i == 0) {//skip subject
continue;
}
// Class
if(i == 1) {
String classname = new ConvertToString().map(row.get(i)).toString();
switch(classname) {
case "BRCA":
scalers = new double[]{1,0,0,0,0};
break;
case "PRAD":
scalers = new double[]{0,1,0,0,0};
break;
case "LUAD":
scalers = new double[]{0,0,1,0,0};
break;
case "KIRC":
scalers = new double[]{0,0,0,1,0};
break;
case "COAD":
scalers = new double[]{0,0,0,0,1};
break;
default:
break;
}
labels[itr] = scalers;
itr++;
}
}
}
System.out.println("LABEL : "+labels.length);//801
//Créer un DataSet
INDArray dataArray = Nd4j.create(dataObj,'c');
System.out.println(dataArray.shapeInfoToString());
INDArray labelArray = Nd4j.create(labels,'c');
System.out.println(labelArray.shapeInfoToString());
//Rank: 2,Offset: 0
// Order: c Shape: [801,20531], stride: [20531,1]
//Rank: 2,Offset: 0
// Order: c Shape: [801,5], stride: [5,1]
DataSet dataset = new DataSet(dataArray, labelArray);
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L));//600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();
System.out.println(train.labelCounts());
System.out.println(test.labelCounts());
//{0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
//{0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}
//MODEL TRAIN AND EVALUATION
int numInput = 20531;
int numOutput = numClasses;
int hiddenNode = 500;//Impuissant
int numEpochs = 50;
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
.seed(123)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.001))
.list()
.layer(0,new LSTM.Builder()
.nIn(numInput)
.nOut(hiddenNode)
.activation(Activation.RELU)
.build())
.layer(1,new LSTM.Builder()
.nIn(hiddenNode)
.nOut(hiddenNode)
.activation(Activation.RELU)
.build())
.layer(2,new LSTM.Builder()
.nIn(hiddenNode)
.nOut(hiddenNode)
.activation(Activation.RELU)
.build())
.layer(3,new RnnOutputLayer.Builder()
.nIn(hiddenNode)
.nOut(numOutput)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunction.MCXENT)//multi class cross entropy
.build())
.pretrain(false)
.backprop(true)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for(int i=0;i<numEpochs;i++) {
model.fit(train);
}
System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(5);
for(DataSet row :test.asList()) {
INDArray testdata = row.getFeatures();
INDArray pred = model.output(testdata);
eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());
TRAIN START...
EVALUATION START...
Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times
==========================Scores========================================
# of classes: 5
Accuracy: 0.9602
Precision: 0.9671
Recall: 0.9284
F1 Score: 0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
========================================================================
Lorsque la quantité de caractéristiques dépasse plusieurs centaines, la régression logistique et SVM (linéaire) sont utilisées, mais j'ai pensé qu'il serait peut-être préférable d'utiliser LSTM. L'étude ne prend pas aussi longtemps que prévu. Je viens de l'essayer, mais je suis content de l'avoir essayé.
pom.xml
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.vis</groupId>
<artifactId>CancerGenomeTest</artifactId>
<version>0.0.1-SNAPSHOT</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
<nd4j.version>1.0.0-alpha</nd4j.version>
<dl4j.version>1.0.0-alpha</dl4j.version>
<datavec.version>1.0.0-alpha</datavec.version>
<arbiter.version>1.0.0-alpha</arbiter.version>
<logback.version>1.2.3</logback.version>
<dl4j.spark.version>1.0.0-alpha_spark_2</dl4j.spark.version>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
<version>${dl4j.spark.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId>
<version>${arbiter.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-ui_2.11</artifactId>
<version>${arbiter.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-codec</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.3.5</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>2.11.0</version>
</dependency>
</dependencies>
</project>