L'autre jour, j'ai vu la nouvelle qu'Oracle avait publié une bibliothèque d'apprentissage automatique en Java en tant que source ouverte, alors je l'ai touchée légèrement.
En regardant cela, il semble que XGBoost, etc. puisse être utilisé en plus de l'algorithme général d'apprentissage automatique.
La page supérieure du site Web officiel présente les trois caractéristiques suivantes.
Provenance: les modèles, ensembles de données et évaluations de Tribuo ont un historique, vous pouvez donc suivre avec précision les paramètres, les méthodes de conversion de données, les fichiers, etc. utilisés pour les créer (*).
Type-safe: En utilisant Java, vous pouvez trouver des erreurs au moment de la compilation plutôt qu'en production.
--Interopérable: fournit une interface aux bibliothèques de machine learning populaires telles que XGBoost et Tensorflow. Il prend en charge le format d'échange de modèles ONNX et peut déployer des modèles intégrés dans d'autres packages et langages (tels que scikit-learn).
*: "Provenance" est une information qui montre comment un modèle ou un ensemble de données a été créé. Je ne sais pas si la traduction est «historique» et appropriée.
Il y avait un Tutoriel pour classer les iris, alors essayons ceci en premier. Dans ce tutoriel, nous allons apprendre les données des iris avec quatre caractéristiques (longueur et largeur des pétales et pétales), et créer et prédire un modèle qui les classe en trois types (versicolor, virginica, setosa).
Créez un projet Maven avec IntelliJ comme indiqué ci-dessous
Ajoutez ce qui suit à pom.xml
.
<dependencies>
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-all</artifactId>
<version>4.0.0</version>
<type>pom</type>
</dependency>
</dependencies>
Téléchargez les données de fleur d'iris pour classer
wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data
Après cela, comme dans le didacticiel, lorsque vous créez une classe, la structure des répertoires sera la suivante.
L'exemple de projet que j'ai créé sera téléchargé sur mon GitHub, donc si vous voulez voir les détails du code source, veuillez vous référer à ici.
Ensuite, exécutez-le immédiatement. Cependant ... j'aurais dû l'implémenter selon le tutoriel, mais pour une raison quelconque, j'ai eu une erreur ...
Exception in thread "main" java.lang.IllegalArgumentException: On row 151 headers has 5 elements, current line has 1 elements.
at org.tribuo.data.csv.CSVIterator.zip(CSVIterator.java:168)
at org.tribuo.data.csv.CSVIterator.getRow(CSVIterator.java:188)
at org.tribuo.data.columnar.ColumnarIterator.hasNext(ColumnarIterator.java:114)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:249)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:161)
at ClassificationExample.main(ClassificationExample.java:21)
Process finished with exit code 1
Le message d'erreur dit «À la ligne 151 ...» », donc quand j'ai vérifié la 151e ligne du fichier de données, il n'y avait qu'un seul saut de ligne dans la dernière ligne du fichier de données ...
Supprimez la ligne vide en pensant «Ne faites pas ça» et exécutez à nouveau. Cette fois, elle a réussi, et 44 des 45 données de vérification étaient correctes. Le taux de réponse correcte était de 97,8%.
Class n tp fn fp recall prec f1
Iris-versicolor 16 16 0 1 1.000 0.941 0.970
Iris-virginica 15 14 1 0 0.933 1.000 0.966
Iris-setosa 14 14 0 0 1.000 1.000 1.000
Total 45 44 1 1
Accuracy 0.978
Micro Average 0.978 0.978 0.978
Macro Average 0.978 0.980 0.978
Balanced Error Rate 0.022
Iris-versicolor Iris-virginica Iris-setosa
Iris-versicolor 16 0 0
Iris-virginica 1 14 0
Iris-setosa 0 0 14
Comme vous pouvez le voir en lisant le tutoriel, je vais expliquer brièvement le code source.
J'ai créé une seule classe avec la méthode main ()
. Certaines classes, etc. requises pour la classification multi-classes sont ʻimport`.
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
...(Abréviation)...
public class ClassificationExample {
public static void main(String[] args) throws IOException {
Chargez le fichier de données téléchargé avec CSVLoader
et maintenez les données dans une classe appelée ListDataSource
.
LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(labelFactory);
String[] irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
ListDataSource irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"), "species", irisHeaders);
Ces données sont divisées en données d'entraînement et données de vérification à 7: 3.
TrainTestSplitter irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);
MutableDataset trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(irisSplitter.getTest());
Avec LogisticRegressionTrainer
, vous pouvez apprendre avec la régression logistique.
Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);
En appelant la méthode LogisticRegressionTrainer.toString ()
de la classe, vous pouvez voir la valeur de l'hyperparamètre que vous utilisez comme suit:
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)
Evaluez la correction des données de validation avec LabelEvaluator.evaluate ()
.
LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);
Le résultat de l'évaluation peut être trouvé en appelant la méthode toString ()
de la classe LabelEvaluation
. Vous pouvez également trouver la matrice de confusion en appelant la méthode getConfusionMatrix ()
.
System.out.println(evaluation);
System.out.println(evaluation.getConfusionMatrix());
C'est la fin de l'apprentissage et de l'évaluation du modèle, mais le tutoriel décrit également comment obtenir la «Provenance» mentionnée ci-dessus.
ModelProvenance provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
La sortie de ce code ressemble à ceci:
TrainTestSplitter(
class-name = org.tribuo.evaluation.TrainTestSplitter
source = CSVLoader(
class-name = org.tribuo.data.csv.CSVLoader
outputFactory = LabelFactory(
class-name = org.tribuo.classification.LabelFactory
)
response-name = species
separator = ,
quote = "
path = file:/home/tamura/git/tribuo-examples/bezdekIris.data
file-modified-time = 2020-09-24T09:05:30+09:00
resource-hash = 36F668D1CBC29A8C2C1128C5D2F0D400FA04ED4DC62D12246F44CE9360360CC0
)
train-proportion = 0.7
seed = 1
size = 150
is-train = true
)
Il semble que vous puissiez voir quel fichier a été lu et comment il a été divisé en données d'entraînement et données de vérification.
Je voudrais ajouter / modifier un peu le code source et vérifier le fonctionnement.
Habituellement, une fois que vous disposez d'un modèle très précis, vous l'utilisez pour prédire des données inconnues. Tribuo peut faire des prédictions avec Model.predict ()
, résultant en un objet Prediction
. Cet objet contient une probabilité de quel type d'iris il s'agit, comme suit:
Je voudrais dire: «Faisons une prédiction», mais comme il n'y a pas de données inconnues, je donnerai des données de vérification à cette méthode. En implémentant ce qui suit, vous pouvez identifier la ligne où la prédiction des données de vérification est incorrecte.
List<Example> data = testingDataset.getData(); //Données pour vérification
for (Example<Label> testingData : data) {
Prediction<Label> predict = irisModel.predict(testingData); //Prédire les données pour vérification une par une
String expectedResult = testingData.getOutput().getLabel(); //Bonne réponse
String predictResult = predict.getOutput().getLabel(); //Résultat prévu
if (!predictResult.equals(expectedResult)) {
System.out.println("Expected result : " + expectedResult);
System.out.println("Predicted result: " + predictResult);
System.out.println(predict.getOutputScores());
}
}
Le résultat de sortie est le suivant.
Expected result : Iris-virginica
Predicted result: Iris-versicolor
{Iris-versicolor=(Iris-versicolor,0.5732799760841581), Iris-virginica=(Iris-virginica,0.42629863727592165), Iris-setosa=(Iris-setosa,4.213866399202189E-4)}
La mauvaise réponse "virginica" a été jugée à 57,3%, et la bonne réponse "versicolor" a été jugée à 42,6%, on peut donc dire qu'un cas qui n'a pas pu prédire était regrettable.
Ensuite, essayez d'utiliser le XGBoost mentionné ci-dessus. Tout ce que vous avez à faire est de changer le Trainer
de LogisticRegressionTrainer
à XGBoostClassificationTrainer
(bien que vous ayez également besoin de ʻimport`, bien sûr).
// Trainer<Label> trainer = new LogisticRegressionTrainer();
Trainer<Label> trainer = new XGBoostClassificationTrainer(2);
Le «2» donné au constructeur est le nombre d'arbres de décision, et ici la valeur minimale «2» est donnée.
Le résultat n'a pas changé et le taux de réponse correcte était de 97,8%.
Cependant, j'écrirai la construction et la méthode pour ceux qui veulent corriger les bogues et ajouter des fonctions ainsi que les utiliser.
La version Tribuo nécessite Java 8 ou version ultérieure et Maven 3.5 ou version ultérieure, alors vérifiez d'abord cela.
$ mvn -version
Apache Maven 3.6.3
Maven home: /usr/share/maven
Java version: 1.8.0_265, vendor: Private Build, runtime: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-47-generic", arch: "amd64", family: "unix"
À propos, cet article utilise Ubuntu 20.04 comme système d'exploitation. Pour construire, lancez simplement le mvn clean package
.
$ mvn clean package
... (Omis) ...
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 03:54 min
[INFO] Finished at: 2020-09-23T16:17:34+09:00
[INFO] ------------------------------------------------------------------------
La construction a été achevée en moins de 4 minutes et environ 350 Mo d'espace disque ont été ajoutés. J'ai déjà essayé Deeplearning4j (https://qiita.com/tamura__246/items/3893ec292284c7128069), mais c'est beaucoup plus léger que cela (Deeplearning4j dispose de dizaines de Go d'espace libre après quelques heures de construction). C'était parti).
Quand je l'ai touché légèrement et j'ai regardé le code source, j'ai eu l'impression que c'était "facile pour Java, mais manquant de fonctionnalités". Pour être honnête, vous pourriez penser: "Avec Python, vous pouvez faire diverses choses plus facilement."
C'est peut-être encore dans le futur, mais comme la vitesse de sélection est rapide dans ce domaine, il est un peu douteux que nous puissions survivre dans le futur. J'aurais aimé avoir une raison claire d'utiliser Tribuo ... (Il semble que les modèles appris dans la bibliothèque Python puissent être utilisés à partir de programmes Java utilisant Tribuo, il peut donc y avoir de telles utilisations). J'attends avec impatience les développements futurs.
Recommended Posts