[JAVA] J'ai essayé le machine learning OSS d'Oracle "Tribuo"

introduction

L'autre jour, j'ai vu la nouvelle qu'Oracle avait publié une bibliothèque d'apprentissage automatique en Java en tant que source ouverte, alors je l'ai touchée légèrement.

En regardant cela, il semble que XGBoost, etc. puisse être utilisé en plus de l'algorithme général d'apprentissage automatique.

HeroAnimation_V2.gif

Fonctionnalité

La page supérieure du site Web officiel présente les trois caractéristiques suivantes.

--Interopérable: fournit une interface aux bibliothèques de machine learning populaires telles que XGBoost et Tensorflow. Il prend en charge le format d'échange de modèles ONNX et peut déployer des modèles intégrés dans d'autres packages et langages (tels que scikit-learn).

*: "Provenance" est une information qui montre comment un modèle ou un ensemble de données a été créé. Je ne sais pas si la traduction est «historique» et appropriée.

Bougez pour le moment

Il y avait un Tutoriel pour classer les iris, alors essayons ceci en premier. Dans ce tutoriel, nous allons apprendre les données des iris avec quatre caractéristiques (longueur et largeur des pétales et pétales), et créer et prédire un modèle qui les classe en trois types (versicolor, virginica, setosa).

iris.png

Créez un projet Maven avec IntelliJ comme indiqué ci-dessous

Screenshot from 2020-09-23 22-04-30.png

Ajoutez ce qui suit à pom.xml.

<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>

Téléchargez les données de fleur d'iris pour classer

wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

Après cela, comme dans le didacticiel, lorsque vous créez une classe, la structure des répertoires sera la suivante.

Screenshot from 2020-09-24 08-50-26.png

L'exemple de projet que j'ai créé sera téléchargé sur mon GitHub, donc si vous voulez voir les détails du code source, veuillez vous référer à ici.

Ensuite, exécutez-le immédiatement. Cependant ... j'aurais dû l'implémenter selon le tutoriel, mais pour une raison quelconque, j'ai eu une erreur ...

Exception in thread "main" java.lang.IllegalArgumentException: On row 151 headers has 5 elements, current line has 1 elements.
	at org.tribuo.data.csv.CSVIterator.zip(CSVIterator.java:168)
	at org.tribuo.data.csv.CSVIterator.getRow(CSVIterator.java:188)
	at org.tribuo.data.columnar.ColumnarIterator.hasNext(ColumnarIterator.java:114)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:249)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:161)
	at ClassificationExample.main(ClassificationExample.java:21)

Process finished with exit code 1

Le message d'erreur dit «À la ligne 151 ...» », donc quand j'ai vérifié la 151e ligne du fichier de données, il n'y avait qu'un seul saut de ligne dans la dernière ligne du fichier de données ...

Screenshot from 2020-09-24 09-05-23.png

Supprimez la ligne vide en pensant «Ne faites pas ça» et exécutez à nouveau. Cette fois, elle a réussi, et 44 des 45 données de vérification étaient correctes. Le taux de réponse correcte était de 97,8%.

Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022
                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa                      0                0               14

Explication du code source

Comme vous pouvez le voir en lisant le tutoriel, je vais expliquer brièvement le code source.

J'ai créé une seule classe avec la méthode main (). Certaines classes, etc. requises pour la classification multi-classes sont ʻimport`.

import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
  ...(Abréviation)...

public class ClassificationExample {
    public static void main(String[] args) throws IOException {

Chargez le fichier de données téléchargé avec CSVLoader et maintenez les données dans une classe appelée ListDataSource.

LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(labelFactory);

String[] irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
ListDataSource irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"), "species", irisHeaders);

Ces données sont divisées en données d'entraînement et données de vérification à 7: 3.

TrainTestSplitter irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);

MutableDataset trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(irisSplitter.getTest());

Avec LogisticRegressionTrainer, vous pouvez apprendre avec la régression logistique.

Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);

En appelant la méthode LogisticRegressionTrainer.toString () de la classe, vous pouvez voir la valeur de l'hyperparamètre que vous utilisez comme suit:

LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)

Evaluez la correction des données de validation avec LabelEvaluator.evaluate ().

LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);

Le résultat de l'évaluation peut être trouvé en appelant la méthode toString () de la classe LabelEvaluation. Vous pouvez également trouver la matrice de confusion en appelant la méthode getConfusionMatrix ().

System.out.println(evaluation);
System.out.println(evaluation.getConfusionMatrix());

C'est la fin de l'apprentissage et de l'évaluation du modèle, mais le tutoriel décrit également comment obtenir la «Provenance» mentionnée ci-dessus.

ModelProvenance provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));

La sortie de ce code ressemble à ceci:

TrainTestSplitter(
	class-name = org.tribuo.evaluation.TrainTestSplitter
	source = CSVLoader(
			class-name = org.tribuo.data.csv.CSVLoader
			outputFactory = LabelFactory(
					class-name = org.tribuo.classification.LabelFactory
				)
			response-name = species
			separator = ,
			quote = "
			path = file:/home/tamura/git/tribuo-examples/bezdekIris.data
			file-modified-time = 2020-09-24T09:05:30+09:00
			resource-hash = 36F668D1CBC29A8C2C1128C5D2F0D400FA04ED4DC62D12246F44CE9360360CC0
		)
	train-proportion = 0.7
	seed = 1
	size = 150
	is-train = true
)

Il semble que vous puissiez voir quel fichier a été lu et comment il a été divisé en données d'entraînement et données de vérification.

application

Je voudrais ajouter / modifier un peu le code source et vérifier le fonctionnement.

Habituellement, une fois que vous disposez d'un modèle très précis, vous l'utilisez pour prédire des données inconnues. Tribuo peut faire des prédictions avec Model.predict (), résultant en un objet Prediction. Cet objet contient une probabilité de quel type d'iris il s'agit, comme suit:

Je voudrais dire: «Faisons une prédiction», mais comme il n'y a pas de données inconnues, je donnerai des données de vérification à cette méthode. En implémentant ce qui suit, vous pouvez identifier la ligne où la prédiction des données de vérification est incorrecte.

List<Example> data = testingDataset.getData(); //Données pour vérification
for (Example<Label> testingData : data) {
    Prediction<Label> predict = irisModel.predict(testingData); //Prédire les données pour vérification une par une
    String expectedResult = testingData.getOutput().getLabel(); //Bonne réponse
    String predictResult = predict.getOutput().getLabel(); //Résultat prévu
    if (!predictResult.equals(expectedResult)) {
        System.out.println("Expected result : " + expectedResult);
        System.out.println("Predicted result: " + predictResult);
        System.out.println(predict.getOutputScores());
    }
}

Le résultat de sortie est le suivant.

Expected result : Iris-virginica
Predicted result: Iris-versicolor
{Iris-versicolor=(Iris-versicolor,0.5732799760841581), Iris-virginica=(Iris-virginica,0.42629863727592165), Iris-setosa=(Iris-setosa,4.213866399202189E-4)}

La mauvaise réponse "virginica" a été jugée à 57,3%, et la bonne réponse "versicolor" a été jugée à 42,6%, on peut donc dire qu'un cas qui n'a pas pu prédire était regrettable.

Ensuite, essayez d'utiliser le XGBoost mentionné ci-dessus. Tout ce que vous avez à faire est de changer le Trainer de LogisticRegressionTrainer à XGBoostClassificationTrainer (bien que vous ayez également besoin de ʻimport`, bien sûr).

// Trainer<Label> trainer = new LogisticRegressionTrainer();
Trainer<Label> trainer = new XGBoostClassificationTrainer(2);

Le «2» donné au constructeur est le nombre d'arbres de décision, et ici la valeur minimale «2» est donnée.

Le résultat n'a pas changé et le taux de réponse correcte était de 97,8%.

Construire

Cependant, j'écrirai la construction et la méthode pour ceux qui veulent corriger les bogues et ajouter des fonctions ainsi que les utiliser.

La version Tribuo nécessite Java 8 ou version ultérieure et Maven 3.5 ou version ultérieure, alors vérifiez d'abord cela.

$ mvn -version
Apache Maven 3.6.3
Maven home: /usr/share/maven
Java version: 1.8.0_265, vendor: Private Build, runtime: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-47-generic", arch: "amd64", family: "unix"

À propos, cet article utilise Ubuntu 20.04 comme système d'exploitation. Pour construire, lancez simplement le mvn clean package.

$ mvn clean package
... (Omis) ...
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  03:54 min
[INFO] Finished at: 2020-09-23T16:17:34+09:00
[INFO] ------------------------------------------------------------------------

La construction a été achevée en moins de 4 minutes et environ 350 Mo d'espace disque ont été ajoutés. J'ai déjà essayé Deeplearning4j (https://qiita.com/tamura__246/items/3893ec292284c7128069), mais c'est beaucoup plus léger que cela (Deeplearning4j dispose de dizaines de Go d'espace libre après quelques heures de construction). C'était parti).

Impressions

Quand je l'ai touché légèrement et j'ai regardé le code source, j'ai eu l'impression que c'était "facile pour Java, mais manquant de fonctionnalités". Pour être honnête, vous pourriez penser: "Avec Python, vous pouvez faire diverses choses plus facilement."

C'est peut-être encore dans le futur, mais comme la vitesse de sélection est rapide dans ce domaine, il est un peu douteux que nous puissions survivre dans le futur. J'aurais aimé avoir une raison claire d'utiliser Tribuo ... (Il semble que les modèles appris dans la bibliothèque Python puissent être utilisés à partir de programmes Java utilisant Tribuo, il peut donc y avoir de telles utilisations). J'attends avec impatience les développements futurs.

référence

Site officiel de Tribuo

Recommended Posts

J'ai essayé le machine learning OSS d'Oracle "Tribuo"
J'ai essayé la machine Spring State
[Apprentissage automatique] J'ai essayé la détection d'objets avec Create ML [détection d'objets]
J'ai essayé de résumer l'apprentissage Java (1)
J'ai essayé l'apprentissage de la gestion qui fait gagner du temps avec Studyplus.
J'ai créé une application d'apprentissage automatique avec Dash (+ Docker) part3 ~ Practice ~
J'ai essayé Spring.
J'ai essayé de mettre Tomcat
J'ai essayé youtubeDataApi.
J'ai essayé de refactoriser ①
J'ai essayé FizzBuzz.
J'ai essayé JHipster 5.1
J'ai essayé l'extraction de texte (OCR) dans Ruby à l'aide de l'API Vision (modèle d'apprentissage automatique formé)
[J'ai essayé] Tutoriel de printemps
J'ai essayé d'exécuter Autoware
J'ai essayé d'utiliser Gson
J'ai essayé d'utiliser TestNG
J'ai essayé Spring Batch
J'ai essayé d'utiliser Galasa
J'ai essayé node-jt400 (Programmes)
J'ai essayé node-jt400 (exécuter)
J'ai essayé node-jt400 (Transactions)
Enfant orienté objet!? J'ai essayé le Deep Learning avec Java (édition d'essai)
J'ai créé une application d'apprentissage automatique avec Dash (+ Docker) part2 ~ Façon basique d'écrire Dash ~