[JAVA] Ich habe das maschinelle Lernen von Oracle OSS "Tribuo" ausprobiert.

Einführung

Neulich sah ich die Nachricht, dass Oracle eine Bibliothek für maschinelles Lernen in Java als Open Source veröffentlicht hat, also berührte ich sie leicht.

Betrachtet man dies, so scheint es, dass XGBoost usw. zusätzlich zum allgemeinen Algorithmus des maschinellen Lernens verwendet werden kann.

HeroAnimation_V2.gif

Charakteristisch

Die oberste Seite der offiziellen Website enthält die folgenden drei Funktionen.

--Interoperable: Bietet eine Schnittstelle zu gängigen Bibliotheken für maschinelles Lernen wie XGBoost und Tensorflow. Es unterstützt das ONNX-Modellaustauschformat und kann Modelle bereitstellen, die in anderen Paketen und Sprachen (z. B. Scikit-Learn) erstellt wurden.

*: "Provenienz" ist eine Information, die zeigt, wie ein Modell oder ein Datensatz erstellt wurde. Ich weiß nicht, ob die Übersetzung "Geschichte" und angemessen ist.

Bewegen Sie sich vorerst

Es gab ein Tutorial zum Klassifizieren von Iris. Versuchen wir es also zuerst. In diesem Tutorial lernen wir die Daten von Schwertlilien mit vier Merkmalen (Länge und Breite von Blütenblättern und Blütenblättern) und erstellen und prognostizieren ein Modell, das sie in drei Typen (versicolor, virginica, setosa) klassifiziert.

iris.png

Erstellen Sie ein Maven-Projekt mit IntelliJ wie unten gezeigt

Screenshot from 2020-09-23 22-04-30.png

Fügen Sie Folgendes zu pom.xml hinzu.

<dependencies>
    <dependency>
        <groupId>org.tribuo</groupId>
        <artifactId>tribuo-all</artifactId>
        <version>4.0.0</version>
        <type>pom</type>
    </dependency>
</dependencies>

Laden Sie die Irisblumendaten herunter, um sie zu klassifizieren

wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

Danach sieht die Verzeichnisstruktur wie im Lernprogramm wie folgt aus, wenn Sie eine Klasse erstellen.

Screenshot from 2020-09-24 08-50-26.png

Das von mir erstellte Beispielprojekt wird auf meinen GitHub hochgeladen. Wenn Sie also die Details des Quellcodes anzeigen möchten, lesen Sie bitte hier.

Führen Sie es dann sofort aus. Allerdings ... Ich hätte es gemäß dem Tutorial implementieren sollen, aber aus irgendeinem Grund habe ich einen Fehler bekommen ...

Exception in thread "main" java.lang.IllegalArgumentException: On row 151 headers has 5 elements, current line has 1 elements.
	at org.tribuo.data.csv.CSVIterator.zip(CSVIterator.java:168)
	at org.tribuo.data.csv.CSVIterator.getRow(CSVIterator.java:188)
	at org.tribuo.data.columnar.ColumnarIterator.hasNext(ColumnarIterator.java:114)
	at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:249)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
	at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:161)
	at ClassificationExample.main(ClassificationExample.java:21)

Process finished with exit code 1

Die Fehlermeldung lautet "In Zeile 151 ...". Als ich also die 151. Zeile der Datendatei überprüfte, gab es nur einen Zeilenumbruch in der letzten Zeile der Datendatei ...

Screenshot from 2020-09-24 09-05-23.png

Löschen Sie die leere Zeile, während Sie "Mach das nicht" denken, und führen Sie sie erneut aus. Diesmal war es erfolgreich und 44 von 45 Verifizierungsdaten waren korrekt. Die korrekte Rücklaufquote betrug 97,8%.

Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022
                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa                      0                0               14

Erklärung des Quellcodes

Wie Sie im Tutorial sehen können, werde ich den Quellcode kurz erläutern.

Ich habe nur eine Klasse mit der Methode main () erstellt. Einige Klassen usw., die für die Klassifizierung mehrerer Klassen erforderlich sind, sind "Import".

import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
  ...(Abkürzung)...

public class ClassificationExample {
    public static void main(String[] args) throws IOException {

Laden Sie die heruntergeladene Datendatei mit "CSVLoader" und halten Sie die Daten in einer Klasse namens "ListDataSource".

LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(labelFactory);

String[] irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
ListDataSource irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"), "species", irisHeaders);

Diese Daten werden 7: 3 in Trainingsdaten und Verifizierungsdaten unterteilt.

TrainTestSplitter irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);

MutableDataset trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(irisSplitter.getTest());

Mit LogisticRegressionTrainer können Sie mit logistischer Regression lernen.

Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);

Durch Aufrufen der Methode "LogisticRegressionTrainer.toString ()" der Klasse können Sie den Wert des verwendeten Hyperparameters wie folgt anzeigen:

LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)

Bewerten Sie mit "LabelEvaluator.evaluate ()", wie korrekt die Validierungsdaten sind.

LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);

Das Ergebnis der Auswertung kann durch Aufrufen der Methode "toString ()" der Klasse "LabelEvaluation" ermittelt werden. Sie können die Verwirrungsmatrix auch finden, indem Sie die Methode getConfusionMatrix () aufrufen.

System.out.println(evaluation);
System.out.println(evaluation.getConfusionMatrix());

Dies ist das Ende des Lernens und der Bewertung des Modells, aber das Tutorial beschreibt auch, wie man die oben erwähnte "Provenienz" erhält.

ModelProvenance provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));

Die Ausgabe dieses Codes sieht folgendermaßen aus:

TrainTestSplitter(
	class-name = org.tribuo.evaluation.TrainTestSplitter
	source = CSVLoader(
			class-name = org.tribuo.data.csv.CSVLoader
			outputFactory = LabelFactory(
					class-name = org.tribuo.classification.LabelFactory
				)
			response-name = species
			separator = ,
			quote = "
			path = file:/home/tamura/git/tribuo-examples/bezdekIris.data
			file-modified-time = 2020-09-24T09:05:30+09:00
			resource-hash = 36F668D1CBC29A8C2C1128C5D2F0D400FA04ED4DC62D12246F44CE9360360CC0
		)
	train-proportion = 0.7
	seed = 1
	size = 150
	is-train = true
)

Es scheint, dass Sie sehen können, welche Datei gelesen wurde und wie sie in Trainingsdaten und Verifizierungsdaten unterteilt wurde.

Anwendung

Ich möchte den Quellcode ein wenig hinzufügen / ändern und den Betrieb überprüfen.

Sobald Sie ein hochgenaues Modell haben, verwenden Sie es normalerweise, um unbekannte Daten vorherzusagen. Tribuo kann mit "Model.predict ()" Vorhersagen treffen, was zu einem "Vorhersage" -Objekt führt. Dieses Objekt enthält eine Wahrscheinlichkeit, um welche Art von Iris es sich handelt, wie folgt:

Ich möchte sagen: "Machen wir eine Vorhersage", aber da es keine unbekannten Daten gibt, werde ich dieser Methode Verifizierungsdaten geben. Indem Sie Folgendes implementieren, können Sie die Zeile identifizieren, in der die Vorhersage der Verifizierungsdaten falsch ist.

List<Example> data = testingDataset.getData(); //Daten zur Überprüfung
for (Example<Label> testingData : data) {
    Prediction<Label> predict = irisModel.predict(testingData); //Prognostizieren Sie die Daten einzeln zur Überprüfung
    String expectedResult = testingData.getOutput().getLabel(); //Richtige Antwort
    String predictResult = predict.getOutput().getLabel(); //Voraussichtliches Ergebnis
    if (!predictResult.equals(expectedResult)) {
        System.out.println("Expected result : " + expectedResult);
        System.out.println("Predicted result: " + predictResult);
        System.out.println(predict.getOutputScores());
    }
}

Das Ausgabeergebnis ist wie folgt.

Expected result : Iris-virginica
Predicted result: Iris-versicolor
{Iris-versicolor=(Iris-versicolor,0.5732799760841581), Iris-virginica=(Iris-virginica,0.42629863727592165), Iris-setosa=(Iris-setosa,4.213866399202189E-4)}

Die falsche Antwort "virginica" wurde mit 57,3% und die richtige Antwort "versicolor" mit 42,6% bewertet. Man kann also sagen, dass ein Fall, der nicht vorhergesagt werden konnte, bedauerlich war.

Versuchen Sie als nächstes, den oben genannten XGBoost zu verwenden. Alles was Sie tun müssen, ist den "Trainer" von "LogisticRegressionTrainer" in "XGBoostClassificationTrainer" zu ändern (obwohl Sie natürlich auch "Import" benötigen).

// Trainer<Label> trainer = new LogisticRegressionTrainer();
Trainer<Label> trainer = new XGBoostClassificationTrainer(2);

Die dem Konstruktor gegebene "2" ist die Anzahl der Entscheidungsbäume, und hier wird der Mindestwert "2" angegeben.

Das Ergebnis änderte sich nicht und die korrekte Antwortrate betrug 97,8%.

Bauen

Ich werde jedoch den Build und die Methode für diejenigen schreiben, die Fehler beheben und Funktionen hinzufügen sowie verwenden möchten.

Tribuo Build erfordert Java 8 oder höher und Maven 3.5 oder höher. Überprüfen Sie dies zuerst.

$ mvn -version
Apache Maven 3.6.3
Maven home: /usr/share/maven
Java version: 1.8.0_265, vendor: Private Build, runtime: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-47-generic", arch: "amd64", family: "unix"

Dieser Artikel verwendet übrigens Ubuntu 20.04 als Betriebssystem. Führen Sie zum Erstellen einfach das mvn clean package aus.

$ mvn clean package
... (weggelassen) ...
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  03:54 min
[INFO] Finished at: 2020-09-23T16:17:34+09:00
[INFO] ------------------------------------------------------------------------

Der Build wurde in weniger als 4 Minuten abgeschlossen und es wurden ca. 350 MB Festplattenspeicher hinzugefügt. Ich habe Deeplearning4j schon einmal ausprobiert (https://qiita.com/tamura__246/items/3893ec292284c7128069), aber es ist viel leichter als das (Deeplearning4j verfügt nach einigen Stunden nach Abschluss des Builds über mehrere zehn GB freien Speicherplatz). Es war weg).

Impressionen

Als ich es leicht berührte und mir den Quellcode ansah, hatte ich den Eindruck, dass es "einfach für Java, aber ohne Funktionalität" sei. Um ehrlich zu sein, könnte man denken: "Mit Python können Sie verschiedene Dinge einfacher erledigen."

Es mag noch in der Zukunft liegen, aber da die Auswahlgeschwindigkeit in diesem Bereich hoch ist, ist es ein wenig zweifelhaft, ob wir in Zukunft überleben können. Ich wünschte, ich hätte einen klaren Grund, Tribuo zu verwenden ... (Es scheint, dass in der Python-Bibliothek erlernte Modelle aus Java-Programmen mit Tribuo verwendet werden können, daher kann es solche Verwendungen geben). Ich freue mich auf zukünftige Entwicklungen.

Referenz

Offizielle Tribuo-Website

Recommended Posts

Ich habe das maschinelle Lernen von Oracle OSS "Tribuo" ausprobiert.
Ich habe Spring State Machine ausprobiert
[Maschinelles Lernen] Ich habe die Objekterkennung mit Create ML [Objekterkennung] ausprobiert.
Ich habe versucht, das Java-Lernen zusammenzufassen (1)
Ich habe mit Studyplus zeitsparendes Management-Lernen versucht.
Ich habe eine App für maschinelles Lernen mit Dash (+ Docker) Teil 3 ~ Übung ~ erstellt
Ich habe es mit Spring versucht.
Ich habe versucht, Tomcat zu setzen
Ich habe youtubeDataApi ausprobiert.
Ich habe versucht, ① umzugestalten
Ich habe FizzBuzz ausprobiert.
Ich habe JHipster 5.1 ausprobiert
Ich habe versucht, Text in Ruby mithilfe der Vision API (trainiertes Modell für maschinelles Lernen) zu extrahieren.
[Ich habe es versucht] Spring Tutorial
Ich habe versucht, Autoware auszuführen
Ich habe versucht, Gson zu benutzen
Ich habe versucht, TestNG zu verwenden
Ich habe Spring Batch ausprobiert
Ich habe versucht, Galasa zu benutzen
Ich habe versucht, node-jt400 (Programme)
Ich habe versucht, node-jt400 (ausführen)
Ich habe versucht, node-jt400 (Transaktionen)
Objektorientiertes Kind !? Ich habe Deep Learning mit Java ausprobiert (Testversion)
Ich habe eine App für maschinelles Lernen mit Dash (+ Docker) Teil 2 ~ Grundlegende Schreibweise für Dash ~ erstellt