The other day I saw the news that Oracle had open-sourced a machine learning library written in Java, so I tried it out briefly.
--CodeZine News: Oracle releases Java-based machine learning library "Tribuo" as open source
--Mynavi News: Oracle announces Java machine learning library "Tribuo"
According to these articles, in addition to common machine learning algorithms, it also supports XGBoost and the like.
The top page of the official website lists the following three features.
--Provenance: Tribuo's models, datasets, and evaluations carry a history, so the parameters used to create them, the data transformations applied, the source files, and so on can be tracked exactly (*).
--Type-safe: Because it uses Java, mistakes are caught at compile time rather than in production.
--Interoperable: Provides interfaces to popular machine learning libraries such as XGBoost and TensorFlow. It supports the ONNX model exchange format and can deploy models built in other packages and languages (such as scikit-learn).
*: "Provenance" is information showing how a model or dataset was created. I am not sure whether "history" is an appropriate translation.
There is a tutorial for classifying irises, so let's try that first. It trains on iris data with four features (sepal length/width and petal length/width) and builds a model that classifies each flower into one of three species (versicolor, virginica, setosa).
Create a Maven project in IntelliJ, then add the following to pom.xml:
<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>
Download the iris data to classify:
wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data
Then create a class as described in the tutorial; the directory structure ends up as follows.
I have uploaded the sample project to my GitHub, so if you want to see the full source code, please look there.
Now let's run it. However... I implemented it exactly as in the tutorial, yet for some reason I got an error...
Exception in thread "main" java.lang.IllegalArgumentException: On row 151 headers has 5 elements, current line has 1 elements.
at org.tribuo.data.csv.CSVIterator.zip(CSVIterator.java:168)
at org.tribuo.data.csv.CSVIterator.getRow(CSVIterator.java:188)
at org.tribuo.data.columnar.ColumnarIterator.hasNext(ColumnarIterator.java:114)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:249)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:161)
at ClassificationExample.main(ClassificationExample.java:21)
Process finished with exit code 1
The error message says "On row 151 ...", so I checked line 151 of the data file: the last line of the file was nothing but a single line break...
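The fix is just to strip the blank line. A minimal sketch, shown here on a throwaway stand-in file rather than the real download (assumes GNU sed; on macOS the in-place flag takes an argument):

```shell
# Stand-in data file with a trailing blank line, mimicking the problem in bezdekIris.data
printf '5.1,3.5,1.4,0.2,Iris-setosa\n4.9,3.0,1.4,0.2,Iris-setosa\n\n' > /tmp/iris-sample.data
# Delete empty (or whitespace-only) lines in place
sed -i '/^[[:space:]]*$/d' /tmp/iris-sample.data
wc -l < /tmp/iris-sample.data  # prints 2
```

Running the same command against the real bezdekIris.data removes the offending line 151.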
Muttering "oh, come on", I deleted the blank line and ran it again. This time it succeeded: 44 of the 45 validation examples were classified correctly, for an accuracy of 97.8%.
Class n tp fn fp recall prec f1
Iris-versicolor 16 16 0 1 1.000 0.941 0.970
Iris-virginica 15 14 1 0 0.933 1.000 0.966
Iris-setosa 14 14 0 0 1.000 1.000 1.000
Total 45 44 1 1
Accuracy 0.978
Micro Average 0.978 0.978 0.978
Macro Average 0.978 0.980 0.978
Balanced Error Rate 0.022
Iris-versicolor Iris-virginica Iris-setosa
Iris-versicolor 16 0 0
Iris-virginica 1 14 0
Iris-setosa 0 0 14
The tutorial explains it well, but let me briefly walk through the source code.
I created a single class with a main() method, and imported the classes required for multi-class classification.
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
... (omitted) ...
public class ClassificationExample {
public static void main(String[] args) throws IOException {
Load the downloaded data file with CSVLoader, which holds the data in a ListDataSource.
LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(labelFactory);
String[] irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
ListDataSource<Label> irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"), "species", irisHeaders);
This data is split 7:3 into training data and validation data.
TrainTestSplitter<Label> irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);
MutableDataset<Label> trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
MutableDataset<Label> testingDataset = new MutableDataset<>(irisSplitter.getTest());
Training with logistic regression is done with LogisticRegressionTrainer.
Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);
Calling the trainer's toString() method reveals the hyperparameter values in use:
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)
LabelEvaluator.evaluate() evaluates how well the model does on the validation data.
LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);
The evaluation result can be obtained by calling the toString() method of the LabelEvaluation class, and the confusion matrix by calling its getConfusionMatrix() method.
System.out.println(evaluation);
System.out.println(evaluation.getConfusionMatrix());
That completes training and evaluating the model, but the tutorial also describes how to obtain the "Provenance" mentioned earlier.
ModelProvenance provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
The output of this code looks like this:
TrainTestSplitter(
class-name = org.tribuo.evaluation.TrainTestSplitter
source = CSVLoader(
class-name = org.tribuo.data.csv.CSVLoader
outputFactory = LabelFactory(
class-name = org.tribuo.classification.LabelFactory
)
response-name = species
separator = ,
quote = "
path = file:/home/tamura/git/tribuo-examples/bezdekIris.data
file-modified-time = 2020-09-24T09:05:30+09:00
resource-hash = 36F668D1CBC29A8C2C1128C5D2F0D400FA04ED4DC62D12246F44CE9360360CC0
)
train-proportion = 0.7
seed = 1
size = 150
is-train = true
)
This shows which file was read and how it was split into training and validation data.
Next, let me modify the source code a little and see how it behaves.
Usually, once you have an accurate model, you use it to predict unknown data. Tribuo makes predictions with Model.predict(), which returns a Prediction object containing the probability of each iris species.
I would like to say "let's predict", but since there is no unknown data, I will feed the validation data to this method instead. The following code finds the validation rows whose prediction is wrong.
List<Example<Label>> data = testingDataset.getData(); // validation data
for (Example<Label> testingData : data) {
Prediction<Label> predict = irisModel.predict(testingData); //Predict data for verification one by one
String expectedResult = testingData.getOutput().getLabel(); //Correct answer
String predictResult = predict.getOutput().getLabel(); //Predicted result
if (!predictResult.equals(expectedResult)) {
System.out.println("Expected result : " + expectedResult);
System.out.println("Predicted result: " + predictResult);
System.out.println(predict.getOutputScores());
}
}
The output is as follows.
Expected result : Iris-virginica
Predicted result: Iris-versicolor
{Iris-versicolor=(Iris-versicolor,0.5732799760841581), Iris-virginica=(Iris-virginica,0.42629863727592165), Iris-setosa=(Iris-setosa,4.213866399202189E-4)}
The wrong answer "versicolor" was given a probability of 57.3% and the correct answer "virginica" 42.6%, so the single failed prediction was a near miss.
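Incidentally, the score map printed above is just a label-to-probability mapping, and the predicted label is simply its argmax. Here is a sketch of that selection using plain Java collections with stand-in values (the class and method names are hypothetical; Tribuo's Prediction.getOutput() already does this for you):

```java
import java.util.Map;

public class ArgmaxSketch {
    // Stand-in for the label -> probability map printed by getOutputScores().
    static String argmax(Map<String, Double> scores) {
        return scores.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow();
    }

    public static void main(String[] args) {
        Map<String, Double> scores = Map.of(
                "Iris-versicolor", 0.5733,
                "Iris-virginica", 0.4263,
                "Iris-setosa", 0.0004);
        System.out.println(argmax(scores)); // prints Iris-versicolor
    }
}
```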
Next, let's try the XGBoost support mentioned earlier. All you have to do is change the Trainer from LogisticRegressionTrainer to XGBoostClassificationTrainer (plus the corresponding import, of course).
// Trainer<Label> trainer = new LogisticRegressionTrainer();
Trainer<Label> trainer = new XGBoostClassificationTrainer(2);
The 2 passed to the constructor is the number of decision trees; here the minimum value, 2, is used.
The result did not change: the accuracy was again 97.8%.
Finally, for those who want to fix bugs or add features rather than just use Tribuo, here is how to build it from source.
Building Tribuo requires Java 8 or later and Maven 3.5 or later, so check those first.
$ mvn -version
Apache Maven 3.6.3
Maven home: /usr/share/maven
Java version: 1.8.0_265, vendor: Private Build, runtime: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-47-generic", arch: "amd64", family: "unix"
By the way, this article uses Ubuntu 20.04 as the OS. To build, just run mvn clean package.
$ mvn clean package
... (Omitted) ...
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 03:54 min
[INFO] Finished at: 2020-09-23T16:17:34+09:00
[INFO] ------------------------------------------------------------------------
The build finished in under 4 minutes and consumed about 350 MB of additional disk space. I have tried Deeplearning4j before (https://qiita.com/tamura__246/items/3893ec292284c7128069), and Tribuo is far lighter: a Deeplearning4j build took hours and ate tens of GB of free disk space.
Having tried it briefly and skimmed the source code, my impression is that it is "easy if you use Java, but lacking in functionality". Honestly, you might well think "with Python you can do all of this more easily".
It may simply be early days, but tools are winnowed quickly in this field, so I am a little doubtful whether it will survive. I wish there were a clear reason to use Tribuo... (models trained with Python libraries can apparently be used from Java programs via Tribuo, so perhaps that is one such use case). I look forward to its future development.