Dare to challenge Kaggle with Java (1)

Overview

Is it possible to compete on Kaggle using only Java? I decided to find out.

What to do in this article

In this article I implement Kaggle's well-known beginner competition, Titanic: Machine Learning from Disaster, in Java only. The task is to predict survival from each passenger's name, gender, age, ticket information, and so on.

[Image: titanic.png]

You build a model from the given training data (train.csv), predict the survival of each passenger in the test data (test.csv), and compete on the accuracy of those predictions. That is the whole competition (though predicting life or death like a game feels a little odd...).

First, let's predict the survivors with a minimal implementation.

- Building a development environment
- Loading the data
- Building a baseline model
- Outputting and submitting predictions

Then implement the following:

- Data analysis
- Data preprocessing
- Feature engineering
- Building several models
- Cross-validation and grid search
- Hyperparameter tuning
- Model ensembling

Java machine learning library

Before that, though: what machine learning libraries does Java even have? The best-known ones are probably the following.

- Apache Mahout: a machine learning library dating back to 2009
- Deeplearning4j: a library specializing in deep learning
- Tribuo: a machine learning library recently released by Oracle

Of these, I will try Tribuo this time.

Predicting survivors with a minimal implementation

Building the development environment

Since this is Java, let's build the survivor prediction model in IntelliJ.

Create a Maven project in IntelliJ as shown below.

[Image: Screenshot from 2020-09-23 22-04-30.png]

Add the following to pom.xml.

<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>

Next, click the "Download All" button on the competition's Data page to download the data needed for prediction, unzip it, and copy it into the Maven project's directory.

The directory structure is as follows.

[Image: Screenshot from 2020-10-07 22-04-01.png]

Loading the data

Let's read the CSV file first. Implement and execute the following:

LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(',', labelFactory);
ListDataSource<Label> dataource = csvLoader.loadDataSource(Paths.get("titanic/train.csv"), "Survived");

But then comes a NumberFormatException...

Exception in thread "main" java.lang.NumberFormatException: For input string: "S"
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:2043)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:23)

Reading Tribuo's source code, CSVLoader is implemented on the assumption that the CSV file consists only of numbers (it always calls Double.parseDouble()), and there seems to be no way to change this behavior. Perhaps Tribuo's current design philosophy is that preprocessing data is outside its scope of responsibility.

At minimum, the non-numeric columns need to be removed from the CSV file. You could manipulate the CSV file with Apache Commons CSV and the like, but with later steps in mind, I will introduce DFLib, a library that looks capable of handling preprocessing. DFLib is a lightweight Java take on Python's Pandas, and it uses Apache Commons CSV internally.

<dependency>
    <groupId>com.nhl.dflib</groupId>
    <artifactId>dflib-csv</artifactId>
    <version>0.8</version>
</dependency>

Before loading the CSV file with CSVLoader, drop the "Name", "Sex", "Ticket", "Cabin", and "Embarked" columns (i.e., narrow the file down to the required columns) as shown below, and save the result to a new CSV file.

DataFrame df = Csv.loader().load("titanic/train.csv");
DataFrame selectColumns = df.selectColumns("Survived", "Pclass", "Age", "SibSp", "Parch", "Fare");
Csv.save(selectColumns, "titanic/train_removed.csv");

LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(',', labelFactory);
ListDataSource<Label> dataource = csvLoader.loadDataSource(Paths.get("titanic/train_removed.csv"), "Survived");

Try again.

Exception in thread "main" java.lang.NumberFormatException: empty String
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:31)

But again, a NumberFormatException. The missing values in the "Age" column appear to be read as empty strings. DataFrame has a fillNulls() method that can replace nulls in one go, but the missing values seem to be interpreted as empty strings ("") rather than nulls, so it does not work as intended (though this may just be an issue with how I load the CSV file...). So, for now, let's drop the "Age" column as well.

DataFrame selectedDataFrame = df.selectColumns("Survived", "Pclass", "SibSp", "Parch", "Fare");

Building a baseline model

This time it worked. Next, we build and train a model. The minimal implementation using logistic regression looks like this:

TrainTestSplitter<Label> dataSplitter = new TrainTestSplitter<>(dataource, 0.7, 1L);
MutableDataset<Label> trainingDataset = new MutableDataset<>(dataSplitter.getTrain());
MutableDataset<Label> testingDataset = new MutableDataset<>(dataSplitter.getTest());

Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> model = trainer.train(trainingDataset);

LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(model, testingDataset);
System.out.println(evaluation);

I finally got a result. Of the 268 validation rows, 163 are correct, an accuracy of 60.8%.

Class                           n          tp          fn          fp      recall        prec          f1
0                             170          89          81          24       0.524       0.788       0.629
1                              98          74          24          81       0.755       0.477       0.585
Total                         268         163         105         105
Accuracy                                                                    0.608
Micro Average                                                               0.608       0.608       0.608
Macro Average                                                               0.639       0.633       0.607
Balanced Error Rate                                                         0.361
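
As a sanity check, the per-class metrics in this table follow directly from the confusion counts; a small standalone snippet (plain Java, independent of Tribuo):

```java
public class MetricsCheck {
    public static void main(String[] args) {
        // Confusion counts for class 0, taken from the evaluation output above
        double tp = 89, fn = 81, fp = 24;
        double recall = tp / (tp + fn);                              // tp / actual positives
        double precision = tp / (tp + fp);                           // tp / predicted positives
        double f1 = 2 * precision * recall / (precision + recall);   // harmonic mean
        double accuracy = 163.0 / 268.0;                             // correct / total
        System.out.printf("recall=%.3f prec=%.3f f1=%.3f acc=%.3f%n",
                recall, precision, f1, accuracy);                    // 0.524 0.788 0.629 0.608
    }
}
```

The printed values match the table row for class 0 and the overall accuracy.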

Outputting and submitting predictions

Now that the model is built, let's load the test data and make predictions. Implement and run the following:

DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
Csv.save(selectedDfTest, "titanic/test_removed.csv");

ListDataSource<Label> dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"), "Survived");
List<Prediction<Label>> predicts = model.predict(dataource4test);
System.out.println(predicts);

However, CSVLoader.loadDataSource() requires the name of the response variable as its second argument. I passed "Survived", but since test.csv has no "Survived" column, I get an error.

Exception in thread "main" java.lang.IllegalArgumentException: Response Survived not found in file file:/home/tamura/git/tribuo-examples/titanic/test_removed.csv
	at org.tribuo.data.csv.CSVLoader.validateResponseNames(CSVLoader.java:286)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:244)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:74)

While grumbling "why must the response variable be included in the CSV file?", there was nothing for it, so I decided to cheat by appending the "Survived" column from the sample submission file (gender_submission.csv) to the DataFrame.

DataFrame dfTest = Csv.loader().load("titanic/test.csv");
DataFrame dfSubmission = Csv.loader().load("titanic/gender_submission.csv");
DataFrame selectedDfTest = dfTest.selectColumns("Pclass", "SibSp", "Parch", "Fare");
selectedDfTest = selectedDfTest.hConcat(dfSubmission.dropColumns("PassengerId"));
Csv.save(selectedDfTest, "titanic/test_removed.csv");

ListDataSource<Label> dataource4test = csvLoader.loadDataSource(Paths.get("titanic/test_removed.csv"), "Survived");
List<Prediction<Label>> predicts = model.predict(dataource4test);
System.out.println(predicts);

When I run it again, thinking this is it, the now-familiar exception appears...

Exception in thread "main" java.lang.NumberFormatException: empty String
	at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1842)
	at sun.misc.FloatingDecimal.parseDouble(FloatingDecimal.java:110)
	at java.lang.Double.parseDouble(Double.java:538)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:260)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:184)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:138)
	at TitanicSurvivalClassifier.main(TitanicSurvivalClassifier.java:73)

Looking at test.csv, there is exactly one missing value, in the "Fare" column... This missing value is also an empty string ("") rather than null, so DataFrame.fillNulls() is no use here either. After examining the source code, I found the empty string can be replaced as follows.

selectedDfTest = selectedDfTest.convertColumn("Fare", s -> "".equals(s) ? "0": s);

The predicted result is now output.

[Prediction(maxLabel=(0,0.5474041397777752),outputScores={0=(0,0.5474041397777752)1=(1,0.4525958602222247}), Prediction(maxLabel=(0,0.6969779586356148),outputScores={0=(0,0.6969779586356148)1=(1,0.303022041364385}), Prediction(maxLabel=(1,0.5302004352989867),outputScores={0=(0,0.46979956470101314)1=(1,0.530200435298986}), Prediction(maxLabel=(0,0.52713643586377),outputScores={0=(0,0.52713643586377)1=(1,0.4728635641362}), Prediction(maxLabel=(0,0.5071805368465395),outputScores={0=(0,0.5071805368465395)1=(1,0.492819463153460}), Prediction(maxLabel=(0,0.5134002908191431),outputScores={0=(0,0.5134002908191431)1=(1,0.4865997091808569}),
...

All that remains is to save the predictions to a CSV file for submission. This would be easy in Python, but it turned out to be surprisingly awkward with DFLib or Tribuo; after browsing their documentation and source code for a while, it looked like it would take some time, so I implemented it with the standard Java API as follows.

AtomicInteger counter = new AtomicInteger(891);
StringBuilder sb = new StringBuilder();
// PassengerId starts at 892; the second character of the label's
// toString() ("(0,0.54...)") is the predicted class (0 or 1)
predicts.forEach(p -> sb.append(counter.incrementAndGet())
        .append(',')
        .append(p.getOutput().toString(), 1, 2)
        .append('\n'));
try (FileWriter fw = new FileWriter("titanic/submission.csv")) {
    fw.write("PassengerId,Survived\n");
    fw.write(sb.toString());
} catch (IOException ex) {
    ex.printStackTrace();
}

Taking the second character of p.getOutput().toString() as the predicted value (0 or 1) and writing it out with java.io.FileWriter is admittedly a crude implementation. Incidentally, counter produces the "PassengerId" values 892 to 1309 required in the submission file.
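
Since p.getOutput().toString() renders as "(label,score)" in the predictions above, a slightly more robust extraction than a fixed character position is to cut at the comma. This is still string parsing rather than a proper Tribuo API call, so treat it as a sketch:

```java
public class LabelExtract {
    // Extracts the label part from a "(label,score)" string such as
    // "(0,0.5474041397777752)"; also works for multi-character labels.
    static String label(String output) {
        return output.substring(1, output.indexOf(','));
    }

    public static void main(String[] args) {
        System.out.println(label("(0,0.5474041397777752)")); // 0
        System.out.println(label("(1,0.5302004352989867)")); // 1
    }
}
```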

Now that the CSV file for submission has been output, upload it to Kaggle's site and check the score.

[Image: Screenshot from 2020-10-02 17-27-35.png]

The score on Kaggle was 0.56220. It's low, but at least I managed to submit.

About the sequel

This time I only did the bare minimum: outputting a submission CSV. Next time, I would like to look into visualizing the data in a Jupyter Notebook, as well as preprocessing, tuning, ensembling, and so on.

Supplement

In this implementation, I loaded the CSV files with CSVLoader, taking it for granted that this is the class to use. However, according to a Tribuo developer, the appropriate approach in this case is actually to use RowProcessor and CSVDataSource instead of CSVLoader. The following implementation gives the same result as the code above.

LabelFactory labelFactory = new LabelFactory();
ResponseProcessor<Label> responseProcessor = new FieldResponseProcessor<>("Survived", "0", labelFactory);
Map<String, FieldProcessor> fieldProcessors = new HashMap<>();
fieldProcessors.put("Pclass", new DoubleFieldProcessor("Pclass"));
fieldProcessors.put("SibSp", new DoubleFieldProcessor("SibSp"));
fieldProcessors.put("Parch", new DoubleFieldProcessor("Parch"));
fieldProcessors.put("Fare", new DoubleFieldProcessor("Fare"));

RowProcessor<Label> rowProcessor = new RowProcessor<>(responseProcessor, fieldProcessors);
Path path = Paths.get("titanic/train.csv");
CSVDataSource<Label> source = new CSVDataSource<>(path, rowProcessor, true);
TrainTestSplitter<Label> dataSplitter = new TrainTestSplitter<>(source, 0.7, 1L);

Furthermore, using an IdentityProcessor on a string column apparently performs one-hot encoding automatically. For example, apply it to the "Sex" column, which takes the two values male and female:

fieldProcessors.put("Sex", new IdentityProcessor("Sex"));

Columns Sex@male and Sex@female with 0/1 values are added (although a single column would suffice for two categories). This alone raises the accuracy to 0.746.
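
Conceptually, the one-hot encoding that IdentityProcessor appears to perform can be sketched in plain Java (the column names like Sex@male are Tribuo's; everything else here is illustrative):

```java
import java.util.Arrays;
import java.util.List;

public class OneHotSketch {
    public static void main(String[] args) {
        List<String> categories = List.of("male", "female"); // distinct values seen in training
        List<String> sexColumn = List.of("male", "female", "female", "male");
        for (String value : sexColumn) {
            int[] row = new int[categories.size()];
            row[categories.indexOf(value)] = 1;              // 1 in the matching column, 0 elsewhere
            System.out.println(value + " -> " + Arrays.toString(row));
        }
    }
}
```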

Similarly, if you apply it to the "Embarked" column, which takes the three values C, Q, and S,

fieldProcessors.put("Embarked", new IdentityProcessor("Embarked"));

I would like to say that accuracy improves further, but it actually drops. The "Embarked" column contains a single row with an empty value; because of it, the training data and the validation data end up with different numbers of columns, and things no longer work properly. I would like to replace the empty value with the mode, S, but that is hard to do as quickly as in Python... I will come back to this next time.
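
For reference, computing the mode itself is straightforward in plain Java; a minimal sketch (the wiring into DFLib/Tribuo is the part deferred to next time):

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ModeImpute {
    // Returns the most frequent non-empty value in a column (its mode).
    static String mode(List<String> column) {
        return column.stream()
                .filter(s -> !s.isEmpty())
                .collect(Collectors.groupingBy(s -> s, Collectors.counting()))
                .entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse("");
    }

    public static void main(String[] args) {
        List<String> embarked = Arrays.asList("S", "C", "S", "", "Q", "S");
        String m = mode(embarked);                     // "S"
        List<String> filled = embarked.stream()
                .map(s -> s.isEmpty() ? m : s)         // replace empty cells with the mode
                .collect(Collectors.toList());
        System.out.println(filled);                    // [S, C, S, S, Q, S]
    }
}
```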

Impressions / conclusions

My answer to the question "Can you compete on Kaggle using only Java?" is "quite difficult at this point" (though I knew that conclusion from the start). Tribuo has only just been released, so there is very little information online, and the guides and Javadocs are not yet sufficient, which makes investigation time-consuming. On top of that, functionality is simply missing: there is not even a method that fills null values with the median.

At the moment, things that are easy in Python are not easy here, but conversely, the many missing features mean many opportunities to contribute (pull requests). If you are a Java programmer interested in machine learning, why not contribute to the project as part of your studies? I have only just started using it, so I will dig a little deeper to see how much it can do.
