[Java] Handling an enormous number of RNA-Seq features with an LSTM


Why I wrote this article

Unexpectedly, someone actually read one of my disorganized articles, so I am writing another one for a change, in the self-styled spirit of "the earth is one → all humanity are siblings → so are my coworkers." (Writing it up also benefits me.)

What I tried

  • Use LSTM layers for the model
  • Predict the cancer type (5 types) by multiclass classification of RNA-Seq data (801 x 20531)

Hardware

  • An ordinary laptop
  • Optional: an NVIDIA GPU (a GTX 1050 Ti this time). AMD GPUs will not work, because DL4J's GPU support relies on CUDA.

Software

  • Ubuntu 18.04 (since this is Java, the OS should not matter)
  • Maven + DL4J (Deeplearning4j) dependencies
  • Eclipse (2018)
  • JDK 8
  • (People interested in this kind of topic are probably already using Python, though.)

Dataset

  • https://archive.ics.uci.edu/ml/datasets/gene+expression+cancer+RNA-Seq

This dataset is a random extract of the gene expression data from The Cancer Genome Atlas (TCGA) Pan-Cancer analysis project, measured with Illumina's HiSeq high-throughput sequencing platform. For more information, see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/

RNA-Seq data lines up, gene by gene, numerical values expressing how strongly each gene is expressed, measured against a reference sequence. The reference sequence is a standard sequence that researchers settle on once a gene's pattern is known to some extent. Expression patterns are attracting attention because they may explain the characteristics and attributes of individuals. In particular, RNA in cancer cells has long been studied in drug discovery research. According to people familiar with the field, even though the human genome has been decoded, there is still a lot of work left to do.

What to classify

The study behind the dataset covers the following cancer types; the sample data includes only five of them (marked ★).

  • LUSC Lung Squamous Carcinoma
  • READ Rectal Adeno Carcinoma
  • GBM Glioblastoma Multiforme
  • LAML Acute Myeloid Leukemia
  • HNSC Head and Neck Squamous Carcinoma
  • BLCA Bladder Carcinoma
  • KIRC Kidney Renal Clear Cell Carcinoma (★)
  • UCEC Uterine Corpus Endometrial Carcinoma
  • LUAD Lung Adenocarcinoma (★)
  • OV Ovarian Carcinoma
  • BRCA Breast Carcinoma (★)
  • COAD Colon Adenocarcinoma (★)
  • PRAD Prostate Adenocarcinoma (★)

The five types of cancer are the classification targets.
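The label-loading code later in this article maps each of these five class names to a one-hot vector with a switch statement. A compact equivalent, as a plain-Java sketch, keeps the same index order (BRCA, PRAD, LUAD, KIRC, COAD). The class `CancerLabels` and its `oneHot` method are hypothetical names for illustration, not part of DL4J.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (not part of DL4J): one-hot encoding for the five target classes
final class CancerLabels {
    // Index order matches the article's switch statement: BRCA=0, PRAD=1, LUAD=2, KIRC=3, COAD=4
    static final List<String> CLASSES = Arrays.asList("BRCA", "PRAD", "LUAD", "KIRC", "COAD");

    static double[] oneHot(String className) {
        int idx = CLASSES.indexOf(className);
        if (idx < 0) {
            // Reject labels outside the five classes instead of silently mislabeling them
            throw new IllegalArgumentException("Unknown class: " + className);
        }
        double[] vector = new double[CLASSES.size()]; // all zeros
        vector[idx] = 1.0;                            // mark the class position
        return vector;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(oneHot("LUAD"))); // [0.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```

Looking the index up in a list, rather than writing one case per class, also makes it impossible to forget a `break` and fall through into the wrong vector.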

Preparing the project

Create a Maven project and add a class file with any name.

Prepare CSV

import java.io.File;

// Read the CSV files (data body and labels)
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv");
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv");

Extract data

import java.io.IOException;
import java.util.List;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

int numClasses = 5;  // 5 classes
int batchSize = 801; // total number of samples

// Read the data body first
RecordReader reader = new CSVRecordReader(1, ','); // skip the header line
try {
    reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
    throw new RuntimeException("Could not open " + dataFile, e);
}

double[][] dataObj = new double[batchSize][];
int itr = 0;
while (reader.hasNext()) {
    List<Writable> row = reader.next();
    double[] scalers = new double[row.size() - 1];
    for (int i = 1; i < row.size(); i++) { // column 0 is the sample ID, skip it
        scalers[i - 1] = Double.parseDouble(row.get(i).toString());
    }
    dataObj[itr++] = scalers;
}
System.out.println("Data samples " + dataObj.length); // 801

// Load the labels and convert them to one-hot vectors (multiclass)
try {
    reader = new CSVRecordReader(1, ','); // skip the header line
    reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
    throw new RuntimeException("Could not open " + labelFile, e);
}
double[][] labels = new double[batchSize][];
itr = 0;
while (reader.hasNext()) {
    List<Writable> row = reader.next();
    // column 0 is the sample ID, column 1 is the class name
    String classname = row.get(1).toString();
    double[] scalers;
    switch (classname) {
        case "BRCA": scalers = new double[]{1, 0, 0, 0, 0}; break;
        case "PRAD": scalers = new double[]{0, 1, 0, 0, 0}; break;
        case "LUAD": scalers = new double[]{0, 0, 1, 0, 0}; break;
        case "KIRC": scalers = new double[]{0, 0, 0, 1, 0}; break;
        case "COAD": scalers = new double[]{0, 0, 0, 0, 1}; break;
        default: throw new IllegalStateException("Unknown class: " + classname);
    }
    labels[itr++] = scalers;
}
System.out.println("LABEL :" + labels.length); // 801
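If you want to check the CSV contents without DataVec, the same "skip the header line, skip the sample-ID column" logic used in the loading code above can be sketched with the JDK alone. This is an illustrative alternative, not what the article uses: the class name `PlainCsvLoader` is hypothetical, and it assumes plain comma-separated values without quoting.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical JDK-only loader (not used in the article); assumes unquoted, comma-separated values
final class PlainCsvLoader {
    // Skips the header line and the first column (sample ID); parses the rest as doubles
    static double[][] loadFeatures(Path csv) throws IOException {
        List<double[]> rows = new ArrayList<>();
        try (BufferedReader in = Files.newBufferedReader(csv, StandardCharsets.UTF_8)) {
            in.readLine(); // discard the header line
            String line;
            while ((line = in.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines
                }
                String[] cols = line.split(",");
                double[] values = new double[cols.length - 1];
                for (int c = 1; c < cols.length; c++) { // column 0 is the sample ID
                    values[c - 1] = Double.parseDouble(cols[c].trim());
                }
                rows.add(values);
            }
        }
        return rows.toArray(new double[0][]);
    }

    public static void main(String[] args) throws IOException {
        double[][] data = loadFeatures(Paths.get("./TCGA-PANCAN-HiSeq-801x20531/data.csv"));
        // The article reports 801 samples x 20531 genes for this file
        System.out.println(data.length + " x " + (data.length > 0 ? data[0].length : 0));
    }
}
```

Reading line by line with a `BufferedReader` avoids holding the whole file as strings in memory at once, which matters for a file of this width.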

Convert the arrays to INDArrays first, then turn them into a DataSet object.

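import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Create INDArrays in 'c' (row-major) order
INDArray dataArray = Nd4j.create(dataObj, 'c');
INDArray labelArray = Nd4j.create(labels, 'c');

// dataArray:  Rank: 2, Offset: 0, Order: c, Shape: [801,20531], Stride: [20531,1]
// labelArray: Rank: 2, Offset: 0, Order: c, Shape: [801,5], Stride: [5,1]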
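The stride values in the comments above follow from the 'c' (row-major) order: the last index varies fastest, so moving one row ahead skips 20531 elements in the flat buffer. A small plain-Java sketch that reproduces those strides and the resulting flat offset; `RowMajor`, `strides` and `flatOffset` are hypothetical helper names, not ND4J API.

```java
import java.util.Arrays;

// Hypothetical helper: strides and flat offsets for a row-major ('c' order) array
final class RowMajor {
    // stride[k] is the product of all dimensions to the right of k
    static long[] strides(long[] shape) {
        long[] stride = new long[shape.length];
        long acc = 1;
        for (int k = shape.length - 1; k >= 0; k--) {
            stride[k] = acc;
            acc *= shape[k];
        }
        return stride;
    }

    // Position of element (index[0], index[1], ...) in the flat buffer, assuming offset 0
    static long flatOffset(long[] index, long[] stride) {
        long pos = 0;
        for (int k = 0; k < index.length; k++) {
            pos += index[k] * stride[k];
        }
        return pos;
    }

    public static void main(String[] args) {
        long[] s = strides(new long[]{801, 20531});
        System.out.println(Arrays.toString(s));               // [20531, 1]
        System.out.println(flatOffset(new long[]{2, 10}, s)); // 41072 = 2 * 20531 + 10
    }
}
```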

Convert to a DataSet and split into training and test sets

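import java.util.Random;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;

DataSet dataset = new DataSet(dataArray, labelArray);
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L)); // 600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();

// Samples per class (0=BRCA, 1=PRAD, 2=LUAD, 3=KIRC, 4=COAD)
// train: {0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
// test:  {0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}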
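The two count maps above are per-class sample counts for the training and test sets; they add up to 600 and 201. The split appears to be a plain shuffle (seeded with 42) rather than a stratified one, so the class proportions differ slightly between the two sets: COAD, for example, is about 10.3% of the training set but 8.0% of the test set. A plain-Java sketch for tallying one-hot labels and comparing proportions; `ClassCounts` is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical helper: per-class counts and proportions for one-hot label rows
final class ClassCounts {
    static int[] count(double[][] oneHotLabels, int numClasses) {
        int[] counts = new int[numClasses];
        for (double[] row : oneHotLabels) {
            for (int c = 0; c < numClasses; c++) {
                if (row[c] == 1.0) { // the single "hot" position is the class index
                    counts[c]++;
                    break;
                }
            }
        }
        return counts;
    }

    static double[] proportions(int[] counts) {
        int total = Arrays.stream(counts).sum();
        double[] p = new double[counts.length];
        for (int c = 0; c < counts.length; c++) {
            p[c] = (double) counts[c] / total;
        }
        return p;
    }

    public static void main(String[] args) {
        // Counts copied from the article's train/test output
        int[] train = {220, 105, 104, 109, 62};
        int[] test = {80, 31, 37, 37, 16};
        System.out.println(Arrays.stream(train).sum() + " / " + Arrays.stream(test).sum()); // 600 / 201
        System.out.printf("COAD share: train %.3f, test %.3f%n",
                proportions(train)[4], proportions(test)[4]); // train 0.103, test 0.080
    }
}
```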

Model building, training and evaluation

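import org.deeplearning4j.eval.Evaluation; // org.nd4j.evaluation.classification.Evaluation in newer DL4J
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

int numInput = 20531;
int numOutput = numClasses;
int hiddenNode = 500; // hidden units per LSTM layer
int numEpochs = 50;

// Per-layer sizes were elided in the original; they are assumed here from hiddenNode
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
    .updater(new Adam(0.001))
    .list()
    .layer(0, new LSTM.Builder().nIn(numInput).nOut(hiddenNode).build())
    .layer(1, new LSTM.Builder().nIn(hiddenNode).nOut(hiddenNode).build())
    .layer(2, new LSTM.Builder().nIn(hiddenNode).nOut(hiddenNode).build())
    .layer(3, new RnnOutputLayer.Builder()
        .lossFunction(LossFunction.MCXENT) // multi-class cross entropy
        .activation(Activation.SOFTMAX)    // softmax output to match MCXENT
        .nIn(hiddenNode).nOut(numOutput)
        .build())
    .build();
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for (int i = 0; i < numEpochs; i++) {
    model.fit(train); // one pass over the 600 training samples per epoch
}
System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(numClasses);
for (DataSet row : test.asList()) { // one test sample at a time
    INDArray pred = model.output(row.getFeatures());
    eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());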

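Part of why 20531 features is "enormous" for this model shows up in the parameter count. As a sketch, assume all three LSTM layers use hiddenNode = 500 units, and that DL4J's LSTM layer (which has no peephole connections) holds input weights, recurrent weights and a bias for each of its four gates, i.e. 4 × nOut × (nIn + nOut + 1) parameters per layer. Under these assumptions the first layer alone accounts for over 90% of roughly 46 million parameters; `model.numParams()` should let you cross-check against the actual network. `LstmParamCount` is a hypothetical helper name.

```java
// Hypothetical helper: parameter counts for the LSTM stack, under the assumptions stated above
final class LstmParamCount {
    // LSTM without peepholes: 4 gates x (input weights nIn*nOut + recurrent weights nOut*nOut + bias nOut)
    static long lstmParams(long nIn, long nOut) {
        return 4L * nOut * (nIn + nOut + 1);
    }

    // Output layer: weight matrix nIn*nOut plus one bias per output
    static long outputParams(long nIn, long nOut) {
        return nIn * nOut + nOut;
    }

    public static void main(String[] args) {
        long numInput = 20531, hiddenNode = 500, numOutput = 5;
        long first = lstmParams(numInput, hiddenNode);       // 42,064,000
        long total = first
                + 2 * lstmParams(hiddenNode, hiddenNode)     // 2 x 2,002,000
                + outputParams(hiddenNode, numOutput);       // 2,505
        // prints total=46070505, first LSTM layer share=91.3%
        System.out.printf("total=%d, first LSTM layer share=%.1f%%%n", total, 100.0 * first / total);
    }
}
```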

Results

Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times

 # of classes:    5
 Accuracy:        0.9602
 Precision:       0.9671
 Recall:          0.9284
 F1 Score:        0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
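The summary numbers follow directly from the confusion matrix above: each prediction is turned into a class by taking the largest softmax output, precision divides the diagonal by the column (predicted) total, recall divides it by the row (actual) total, and "macro-averaged" means a plain average over the five classes. A plain-Java sketch that recomputes them; `MacroMetrics` is a hypothetical name, and its handling of classes with no samples (counted as 0) may differ from DL4J's.

```java
// Hypothetical helper: recomputes macro-averaged metrics from a confusion matrix
final class MacroMetrics {
    // Index of the largest probability, i.e. the predicted class for one softmax output row
    static int argmax(double[] probs) {
        int best = 0;
        for (int c = 1; c < probs.length; c++) {
            if (probs[c] > probs[best]) best = c;
        }
        return best;
    }

    // cm[actual][predicted]
    static double accuracy(int[][] cm) {
        long correct = 0, total = 0;
        for (int a = 0; a < cm.length; a++) {
            for (int p = 0; p < cm.length; p++) {
                total += cm[a][p];
                if (a == p) correct += cm[a][p];
            }
        }
        return (double) correct / total;
    }

    // Returns {macro precision, macro recall, macro F1}; empty classes count as 0 in this sketch
    static double[] macro(int[][] cm) {
        int k = cm.length;
        double sumP = 0, sumR = 0, sumF1 = 0;
        for (int c = 0; c < k; c++) {
            long predicted = 0, actual = 0;
            for (int i = 0; i < k; i++) {
                predicted += cm[i][c]; // column sum
                actual += cm[c][i];    // row sum
            }
            double precision = predicted == 0 ? 0 : (double) cm[c][c] / predicted;
            double recall = actual == 0 ? 0 : (double) cm[c][c] / actual;
            double f1 = precision + recall == 0 ? 0 : 2 * precision * recall / (precision + recall);
            sumP += precision;
            sumR += recall;
            sumF1 += f1; // macro F1 = mean of per-class F1
        }
        return new double[]{sumP / k, sumR / k, sumF1 / k};
    }

    public static void main(String[] args) {
        // Confusion matrix transcribed from the evaluation output (rows = actual, columns = predicted)
        int[][] cm = {
            {80, 0, 0, 0, 0},
            {0, 31, 0, 0, 0},
            {3, 0, 34, 0, 0},
            {0, 0, 1, 36, 0},
            {0, 0, 4, 0, 12},
        };
        double[] m = macro(cm);
        // prints Accuracy 0.9602, Precision 0.9671, Recall 0.9284, F1 0.9440
        System.out.printf("Accuracy %.4f, Precision %.4f, Recall %.4f, F1 %.4f%n",
                accuracy(cm), m[0], m[1], m[2]);
    }
}
```

The per-class view also shows where the errors are: COAD (class 4) has perfect precision but only 12/16 = 0.75 recall, since four COAD samples were predicted as LUAD.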


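Logistic regression or a linear SVM is typically used when there are more than a few hundred features, but I thought an LSTM might also be a good option. Training also took less time than I expected. I only tried it out, but I'm glad I did.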





<project xmlns="http://maven.apache.org/POM/4.0.0"
 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
 xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">