[JAVA] I tried Oracle's machine learning OSS "Tribuo"

Introduction

The other day, I saw news that Oracle had released a machine learning library for Java as open source, so I tried it out briefly.

- CodeZine News: Oracle releases Java-based machine learning library "Tribuo" as open source
- Mynavi News: Oracle Announces Java Machine Learning Library "Tribuo"

According to these articles, libraries such as XGBoost can be used in addition to the usual machine learning algorithms.


Features

The top page of the official website lists the following three features.

- Provenance: Tribuo's models, datasets, and evaluations carry provenance, so the parameters used to create them, the data transformations applied, the source files, and so on can be tracked accurately (*).

- Type-safe: Because it is written in Java, you find mistakes at compile time instead of in production.

- Interoperable: It provides interfaces to popular machine learning libraries such as XGBoost and TensorFlow. It supports the ONNX model exchange format and can deploy models built with other packages and languages (such as scikit-learn).

*: "Provenance" is information describing how a model or dataset was created. I'm not sure "history" is an appropriate translation.

Getting it running

There is a tutorial on classifying irises, so let's try that first. In this tutorial, we train a model on iris data with four features (sepal length and width, petal length and width) and use it to predict which of three species (versicolor, virginica, setosa) each flower belongs to.


Create a Maven project in IntelliJ as shown below.

[Screenshot: creating a new Maven project in IntelliJ]

Add the following to pom.xml.

<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>

Download the iris data to classify:

wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

After creating the class as in the tutorial, the directory structure looks like this:

[Screenshot: project directory structure]

I have uploaded the sample project to my GitHub, so please refer to it if you want to see the source code in detail.

Let's run it right away. However... even though I implemented it exactly as in the tutorial, I got an error for some reason...

Exception in thread "main" java.lang.IllegalArgumentException: On row 151 headers has 5 elements, current line has 1 elements.
	at org.tribuo.data.csv.CSVIterator.zip(CSVIterator.java:168)
	at org.tribuo.data.csv.CSVIterator.getRow(CSVIterator.java:188)
	at org.tribuo.data.columnar.ColumnarIterator.hasNext(ColumnarIterator.java:114)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:249)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:161)
	at ClassificationExample.main(ClassificationExample.java:21)

Process finished with exit code 1

The error message says "On row 151 ...", so I checked line 151 of the data file: the last line of the file was nothing but a line break...

[Screenshot: the empty last line (line 151) of bezdekIris.data]
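If you would rather not edit the downloaded file by hand (for example, when the download is scripted), the empty line can be dropped in a small pre-processing step. This is a minimal sketch using only java.nio.file; the class name and the output file name bezdekIris-clean.data are hypothetical choices, not part of the tutorial.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;

public class StripBlankLines {

    // Returns the input lines with every empty or whitespace-only line removed.
    static List<String> removeBlankLines(List<String> lines) {
        return lines.stream()
                .filter(line -> !line.trim().isEmpty())
                .collect(Collectors.toList());
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical file names: read the original download, write a cleaned copy.
        Path in = Paths.get("bezdekIris.data");
        Path out = Paths.get("bezdekIris-clean.data");
        List<String> cleaned = removeBlankLines(Files.readAllLines(in, StandardCharsets.UTF_8));
        Files.write(out, cleaned, StandardCharsets.UTF_8);
        System.out.println("Wrote " + cleaned.size() + " rows to " + out);
    }
}
```

Passing the cleaned path to loadDataSource (or overwriting the original file) should avoid the "On row 151" error, since the loader no longer sees an empty row.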

Muttering "don't do that," I deleted the blank line and ran it again. This time it succeeded: 44 of the 45 test examples were classified correctly, for an accuracy of 97.8%.

Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022
                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa                      0                0               14
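The figures in this report can be recomputed from the confusion matrix alone. The sketch below is plain Java (no Tribuo classes); the counts are copied from the matrix above, and it assumes rows are the true species and columns the predicted species, which agrees with the n/tp/fn/fp columns. The class and method names are hypothetical.

```java
public class ConfusionMetrics {

    // Rows are the true class, columns the predicted class, in the order
    // versicolor, virginica, setosa (counts copied from the output above).
    static final String[] LABELS = {"Iris-versicolor", "Iris-virginica", "Iris-setosa"};
    static final int[][] MATRIX = {
            {16, 0, 0},
            {1, 14, 0},
            {0, 0, 14}
    };

    // Recall = tp / (tp + fn) = diagonal / row sum.
    static double recall(int[][] m, int c) {
        int rowSum = 0;
        for (int j = 0; j < m.length; j++) rowSum += m[c][j];
        return (double) m[c][c] / rowSum;
    }

    // Precision = tp / (tp + fp) = diagonal / column sum.
    static double precision(int[][] m, int c) {
        int colSum = 0;
        for (int i = 0; i < m.length; i++) colSum += m[i][c];
        return (double) m[c][c] / colSum;
    }

    // Harmonic mean of precision and recall.
    static double f1(double p, double r) {
        return 2 * p * r / (p + r);
    }

    // Correct predictions (the diagonal) divided by all predictions.
    static double accuracy(int[][] m) {
        int correct = 0, total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m.length; j++) total += m[i][j];
            correct += m[i][i];
        }
        return (double) correct / total;
    }

    public static void main(String[] args) {
        double macroRecall = 0;
        for (int c = 0; c < LABELS.length; c++) {
            double p = precision(MATRIX, c), r = recall(MATRIX, c);
            macroRecall += r / LABELS.length;
            System.out.printf("%-16s recall=%.3f prec=%.3f f1=%.3f%n", LABELS[c], r, p, f1(p, r));
        }
        System.out.printf("Accuracy=%.3f%n", accuracy(MATRIX));
        System.out.printf("Balanced error rate=%.3f%n", 1 - macroRecall);
    }
}
```

For versicolor, for example, precision is 16/17 ≈ 0.941 because one virginica flower was predicted as versicolor; accuracy is 44/45 ≈ 0.978, and 1 minus the macro-averaged recall gives 0.022, matching the balanced error rate above. The sketch does not guard against empty rows or columns, which would cause a division by zero.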

Explanation of source code

The tutorial explains this in detail, but I will briefly go over the source code.

I created just one class, which has the main() method. It imports the classes needed for multi-class classification, among others.

import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
  ...(omitted)...

public class ClassificationExample {
    public static void main(String[] args) throws IOException {

Load the downloaded data file with CSVLoader, which returns the data as a ListDataSource.

LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(labelFactory);

String[] irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
// "species" is the name of the column that holds the label to predict
ListDataSource<Label> irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"), "species", irisHeaders);

This data is split into training data and test data at a ratio of 7:3.

// 70% of the rows go to training; 1L is the random seed for the split
TrainTestSplitter<Label> irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);

MutableDataset<Label> trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
MutableDataset<Label> testingDataset = new MutableDataset<>(irisSplitter.getTest());
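As a rough illustration of what a seeded 7:3 split does, here is a minimal sketch that shuffles row indices with java.util.Random and cuts the list at 70%. It shows the general technique only; it is not a copy of TrainTestSplitter's implementation, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SplitSketch {

    // Shuffles the row indices 0..size-1 with a fixed seed and cuts them at
    // trainProportion. Returns a list holding {trainIndices, testIndices}.
    static List<List<Integer>> split(int size, double trainProportion, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) indices.add(i);
        Collections.shuffle(indices, new Random(seed));
        int trainSize = (int) Math.round(size * trainProportion);
        List<List<Integer>> result = new ArrayList<>();
        result.add(new ArrayList<>(indices.subList(0, trainSize)));
        result.add(new ArrayList<>(indices.subList(trainSize, size)));
        return result;
    }

    public static void main(String[] args) {
        // 150 iris rows, 70% training data, seed 1 (the same values passed to TrainTestSplitter).
        List<List<Integer>> parts = split(150, 0.7, 1L);
        System.out.println("train = " + parts.get(0).size() + ", test = " + parts.get(1).size());
    }
}
```

With 150 rows this prints "train = 105, test = 45", and 45 is also the number of test examples in the evaluation output above. Fixing the seed makes the split identical on every run, which is what lets the provenance record it as "seed = 1".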

LogisticRegressionTrainer trains a logistic regression model.

Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);

Calling toString() on the LogisticRegressionTrainer shows the hyperparameter values it uses, as follows:

LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)
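To get a feel for what optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0) means, here is a minimal sketch of the textbook AdaGrad update on a hypothetical one-parameter toy loss (w - 3)^2. It assumes that initialValue is the starting value of the squared-gradient accumulator and that epsilon is added to its square root; Tribuo's own implementation may differ in details. (epochs=5 and minibatchSize=1 mean five passes over the training data, updating after every single example.)

```java
public class AdaGradSketch {

    // One-dimensional textbook AdaGrad on the toy loss (w - 3)^2.
    // The accumulator starts at initialValue and collects squared gradients;
    // each step is scaled by learningRate / (sqrt(accumulator) + epsilon).
    static double minimise(double start, double learningRate, double epsilon,
                           double initialValue, int steps) {
        double w = start;
        double accumulator = initialValue;
        for (int t = 0; t < steps; t++) {
            double grad = 2 * (w - 3.0); // gradient of (w - 3)^2
            accumulator += grad * grad;
            w -= learningRate * grad / (Math.sqrt(accumulator) + epsilon);
        }
        return w;
    }

    public static void main(String[] args) {
        // Same values as printed by toString(): learning rate 1.0, epsilon 0.1, initialValue 0.0.
        for (int steps : new int[]{1, 5, 50}) {
            System.out.println(steps + " steps: w = " + minimise(0.0, 1.0, 0.1, 0.0, steps));
        }
    }
}
```

Because each step divides the learning rate by the root of all squared gradients seen so far, large early gradients automatically shrink later steps, which is why a fairly large initial learning rate such as 1.0 still converges here (w moves toward 3).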

LabelEvaluator.evaluate() evaluates how well the model performs on the test data.

LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);

Calling toString() on the resulting LabelEvaluation gives the evaluation results, and getConfusionMatrix() gives the confusion matrix.

System.out.println(evaluation);
System.out.println(evaluation.getConfusionMatrix());

That completes training and evaluating the model, but the tutorial also shows how to retrieve the "Provenance" mentioned above.

ModelProvenance provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));

The output of this code looks like this:

TrainTestSplitter(
	class-name = org.tribuo.evaluation.TrainTestSplitter
	source = CSVLoader(
			class-name = org.tribuo.data.csv.CSVLoader
			outputFactory = LabelFactory(
					class-name = org.tribuo.classification.LabelFactory
				)
			response-name = species
			separator = ,
			quote = "
			path = file:/home/tamura/git/tribuo-examples/bezdekIris.data
			file-modified-time = 2020-09-24T09:05:30+09:00
			resource-hash = 36F668D1CBC29A8C2C1128C5D2F0D400FA04ED4DC62D12246F44CE9360360CC0
		)
	train-proportion = 0.7
	seed = 1
	size = 150
	is-train = true
)

This shows which file was read and how it was split into training and test data.
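The resource-hash field is 64 hexadecimal characters, which is the length of a SHA-256 digest. Assuming it is SHA-256 over the raw bytes of the file (I have not confirmed this against Tribuo's source), you could check whether the data file still matches the one the model was trained on with a sketch like this; the class and method names are hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

public class FileHashSketch {

    // Returns the SHA-256 digest of the given bytes as upper-case hex,
    // the same letter case as the resource-hash line in the provenance output.
    static String sha256Hex(byte[] data) {
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256").digest(data);
            StringBuilder hex = new StringBuilder();
            for (byte b : digest) hex.append(String.format("%02X", b));
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to support SHA-256, so this should not happen.
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        // Compare the printed value with the resource-hash in the provenance.
        byte[] bytes = Files.readAllBytes(Paths.get("bezdekIris.data"));
        System.out.println(sha256Hex(bytes));
    }
}
```

If the values differ, the file has changed since training (note that deleting the blank line, as above, also changes the hash), or the assumption about the hash type does not hold.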

Applications

Let's modify the source code a little and see how it behaves.

Normally, once you have an accurate model, you use it to make predictions on unseen data. In Tribuo, Model.predict() returns a Prediction object, which contains the probability of each iris species.
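The trainer above optimises LogMulticlass, i.e. a multinomial logistic regression. The textbook way such a model turns per-class scores into probabilities is the softmax function, and the predicted label is the class with the highest probability. This is a minimal sketch of that step in plain Java; the scores are made-up numbers and the code is an illustration, not Tribuo's internal implementation.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class SoftmaxSketch {

    // Converts raw per-class scores into probabilities that sum to 1.
    // Subtracting the maximum score first avoids overflow in Math.exp.
    static double[] softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) max = Math.max(max, s);
        double sum = 0;
        double[] probs = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) probs[i] /= sum;
        return probs;
    }

    // Index of the largest value, i.e. the predicted class.
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) if (values[i] > values[best]) best = i;
        return best;
    }

    public static void main(String[] args) {
        String[] labels = {"Iris-versicolor", "Iris-virginica", "Iris-setosa"};
        double[] scores = {1.2, 0.9, -6.0}; // hypothetical scores, not from a real model
        double[] probs = softmax(scores);
        Map<String, Double> byLabel = new LinkedHashMap<>();
        for (int i = 0; i < labels.length; i++) byLabel.put(labels[i], probs[i]);
        System.out.println(byLabel);
        System.out.println("Predicted: " + labels[argmax(probs)]);
    }
}
```

Two close scores produce two close probabilities, which is the kind of near tie you can see in the misclassified example printed by getOutputScores().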

I'd like to say "let's make a prediction," but since I have no unseen data, I will pass the test data to this method instead. The following code finds the test examples whose predictions are wrong.

List<Example<Label>> data = testingDataset.getData(); // Test data
for (Example<Label> testingData : data) {
    Prediction<Label> predict = irisModel.predict(testingData); // Predict each test example one by one
    String expectedResult = testingData.getOutput().getLabel(); // Correct answer
    String predictResult = predict.getOutput().getLabel(); // Predicted result
    if (!predictResult.equals(expectedResult)) {
        System.out.println("Expected result : " + expectedResult);
        System.out.println("Predicted result: " + predictResult);
        System.out.println(predict.getOutputScores()); // Probability for each species
    }
}

The output result is as follows.

Expected result : Iris-virginica
Predicted result: Iris-versicolor
{Iris-versicolor=(Iris-versicolor,0.5732799760841581), Iris-virginica=(Iris-virginica,0.42629863727592165), Iris-setosa=(Iris-setosa,4.213866399202189E-4)}

The model gave the wrong answer, "versicolor", a probability of 57.3% and the correct answer, "virginica", 42.6%, so this one misclassification was a near miss.

Next, let's try XGBoost, mentioned above. All you have to do is change the Trainer from LogisticRegressionTrainer to XGBoostClassificationTrainer (plus the corresponding import, of course).

// Trainer<Label> trainer = new LogisticRegressionTrainer();
Trainer<Label> trainer = new XGBoostClassificationTrainer(2);

The 2 passed to the constructor is the number of decision trees; here I gave the minimum value, 2.

The result did not change: the accuracy was again 97.8%.

Build

For those who want not only to use Tribuo but also to fix bugs or add features, here is how to build it.

Building Tribuo requires Java 8 or later and Maven 3.5 or later, so check those first.

$ mvn -version
Apache Maven 3.6.3
Maven home: /usr/share/maven
Java version: 1.8.0_265, vendor: Private Build, runtime: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-47-generic", arch: "amd64", family: "unix"

By the way, this article uses Ubuntu 20.04 as the OS. To build, just run mvn clean package.

$ mvn clean package
... (Omitted) ...
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  03:54 min
[INFO] Finished at: 2020-09-23T16:17:34+09:00
[INFO] ------------------------------------------------------------------------

The build finished in under 4 minutes and used about 350 MB of disk space. I have tried Deeplearning4j before (https://qiita.com/tamura__246/items/3893ec292284c7128069), and Tribuo is much lighter (with Deeplearning4j, the build took a few hours, and by the time it finished, tens of GB of free disk space were gone).

Impressions

After trying it briefly and reading the source code, my impression was that it is "easy to use for Java, but lacking in features." To be honest, you might think, "Python lets you do more, more easily."

It may still be early days, but this field moves fast and tools are quickly weeded out, so I have some doubts about whether Tribuo will survive. I wish there were a clear reason to choose Tribuo... (It seems that models trained with Python libraries can be used from Java programs via Tribuo, so there may be use cases like that.) I look forward to future developments.

Reference

Tribuo official website
