Neulich sah ich die Nachricht, dass Oracle eine Bibliothek für maschinelles Lernen in Java als Open Source veröffentlicht hat, also berührte ich sie leicht.
Betrachtet man dies, so scheint es, dass XGBoost usw. zusätzlich zum allgemeinen Algorithmus des maschinellen Lernens verwendet werden kann.
Die oberste Seite der offiziellen Website enthält die folgenden drei Funktionen.
Provenienz: Die Modelle, Datensätze und Bewertungen von Tribuo haben einen Verlauf, sodass Sie die Parameter, Datenkonvertierungsmethoden, Dateien usw., mit denen sie erstellt wurden, genau verfolgen können (*).
Typensicher: Mit Java können Sie Fehler beim Kompilieren und nicht in der Produktion finden.
--Interoperable: Bietet eine Schnittstelle zu gängigen Bibliotheken für maschinelles Lernen wie XGBoost und Tensorflow. Es unterstützt das ONNX-Modellaustauschformat und kann Modelle bereitstellen, die in anderen Paketen und Sprachen (z. B. Scikit-Learn) erstellt wurden.
*: "Provenienz" ist eine Information, die zeigt, wie ein Modell oder ein Datensatz erstellt wurde. Ich weiß nicht, ob die Übersetzung "Geschichte" und angemessen ist.
Es gab ein Tutorial zum Klassifizieren von Iris. Versuchen wir es also zuerst. In diesem Tutorial lernen wir die Daten von Schwertlilien mit vier Merkmalen (Länge und Breite von Blütenblättern und Blütenblättern) und erstellen und prognostizieren ein Modell, das sie in drei Typen (versicolor, virginica, setosa) klassifiziert.
Erstellen Sie ein Maven-Projekt mit IntelliJ wie unten gezeigt
Fügen Sie Folgendes zu pom.xml
hinzu.
<dependencies>
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-all</artifactId>
<version>4.0.0</version>
<type>pom</type>
</dependency>
</dependencies>
Laden Sie die Irisblumendaten herunter, um sie zu klassifizieren
wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data
Danach sieht die Verzeichnisstruktur wie im Lernprogramm wie folgt aus, wenn Sie eine Klasse erstellen.
Das von mir erstellte Beispielprojekt wird auf meinen GitHub hochgeladen. Wenn Sie also die Details des Quellcodes anzeigen möchten, lesen Sie bitte hier.
Führen Sie es dann sofort aus. Allerdings ... Ich hätte es gemäß dem Tutorial implementieren sollen, aber aus irgendeinem Grund habe ich einen Fehler bekommen ...
Exception in thread "main" java.lang.IllegalArgumentException: On row 151 headers has 5 elements, current line has 1 elements.
at org.tribuo.data.csv.CSVIterator.zip(CSVIterator.java:168)
at org.tribuo.data.csv.CSVIterator.getRow(CSVIterator.java:188)
at org.tribuo.data.columnar.ColumnarIterator.hasNext(ColumnarIterator.java:114)
at org.tribuo.data.csv.CSVLoader.innerLoadFromCSV(CSVLoader.java:249)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:238)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:209)
at org.tribuo.data.csv.CSVLoader.loadDataSource(CSVLoader.java:161)
at ClassificationExample.main(ClassificationExample.java:21)
Process finished with exit code 1
Die Fehlermeldung lautet "In Zeile 151 ...
". Als ich also die 151. Zeile der Datendatei überprüfte, gab es nur einen Zeilenumbruch in der letzten Zeile der Datendatei ...
Löschen Sie die leere Zeile, während Sie "Mach das nicht" denken, und führen Sie sie erneut aus. Diesmal war es erfolgreich und 44 von 45 Verifizierungsdaten waren korrekt. Die korrekte Rücklaufquote betrug 97,8%.
Class n tp fn fp recall prec f1
Iris-versicolor 16 16 0 1 1.000 0.941 0.970
Iris-virginica 15 14 1 0 0.933 1.000 0.966
Iris-setosa 14 14 0 0 1.000 1.000 1.000
Total 45 44 1 1
Accuracy 0.978
Micro Average 0.978 0.978 0.978
Macro Average 0.978 0.980 0.978
Balanced Error Rate 0.022
Iris-versicolor Iris-virginica Iris-setosa
Iris-versicolor 16 0 0
Iris-virginica 1 14 0
Iris-setosa 0 0 14
Wie Sie im Tutorial sehen können, werde ich den Quellcode kurz erläutern.
Ich habe nur eine Klasse mit der Methode main () erstellt. Einige Klassen usw., die für die Klassifizierung mehrerer Klassen erforderlich sind, sind "Import".
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
...(Abkürzung)...
public class ClassificationExample {
public static void main(String[] args) throws IOException {
Laden Sie die heruntergeladene Datendatei mit "CSVLoader" und halten Sie die Daten in einer Klasse namens "ListDataSource".
LabelFactory labelFactory = new LabelFactory();
CSVLoader csvLoader = new CSVLoader<>(labelFactory);
String[] irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
ListDataSource irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"), "species", irisHeaders);
Diese Daten werden 7: 3 in Trainingsdaten und Verifizierungsdaten unterteilt.
TrainTestSplitter irisSplitter = new TrainTestSplitter<>(irisesSource, 0.7, 1L);
MutableDataset trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
MutableDataset testingDataset = new MutableDataset<>(irisSplitter.getTest());
Mit LogisticRegressionTrainer
können Sie mit logistischer Regression lernen.
Trainer<Label> trainer = new LogisticRegressionTrainer();
Model<Label> irisModel = trainer.train(trainingDataset);
Durch Aufrufen der Methode "LogisticRegressionTrainer.toString ()" der Klasse können Sie den Wert des verwendeten Hyperparameters wie folgt anzeigen:
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)
Bewerten Sie mit "LabelEvaluator.evaluate ()", wie korrekt die Validierungsdaten sind.
LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(irisModel, testingDataset);
Das Ergebnis der Auswertung kann durch Aufrufen der Methode "toString ()" der Klasse "LabelEvaluation" ermittelt werden. Sie können die Verwirrungsmatrix auch finden, indem Sie die Methode getConfusionMatrix () aufrufen.
System.out.println(evaluation);
System.out.println(evaluation.getConfusionMatrix());
Dies ist das Ende des Lernens und der Bewertung des Modells, aber das Tutorial beschreibt auch, wie man die oben erwähnte "Provenienz" erhält.
ModelProvenance provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
Die Ausgabe dieses Codes sieht folgendermaßen aus:
TrainTestSplitter(
class-name = org.tribuo.evaluation.TrainTestSplitter
source = CSVLoader(
class-name = org.tribuo.data.csv.CSVLoader
outputFactory = LabelFactory(
class-name = org.tribuo.classification.LabelFactory
)
response-name = species
separator = ,
quote = "
path = file:/home/tamura/git/tribuo-examples/bezdekIris.data
file-modified-time = 2020-09-24T09:05:30+09:00
resource-hash = 36F668D1CBC29A8C2C1128C5D2F0D400FA04ED4DC62D12246F44CE9360360CC0
)
train-proportion = 0.7
seed = 1
size = 150
is-train = true
)
Es scheint, dass Sie sehen können, welche Datei gelesen wurde und wie sie in Trainingsdaten und Verifizierungsdaten unterteilt wurde.
Ich möchte den Quellcode ein wenig hinzufügen / ändern und den Betrieb überprüfen.
Sobald Sie ein hochgenaues Modell haben, verwenden Sie es normalerweise, um unbekannte Daten vorherzusagen. Tribuo kann mit "Model.predict ()" Vorhersagen treffen, was zu einem "Vorhersage" -Objekt führt. Dieses Objekt enthält eine Wahrscheinlichkeit, um welche Art von Iris es sich handelt, wie folgt:
Ich möchte sagen: "Machen wir eine Vorhersage", aber da es keine unbekannten Daten gibt, werde ich dieser Methode Verifizierungsdaten geben. Indem Sie Folgendes implementieren, können Sie die Zeile identifizieren, in der die Vorhersage der Verifizierungsdaten falsch ist.
List<Example> data = testingDataset.getData(); //Daten zur Überprüfung
for (Example<Label> testingData : data) {
Prediction<Label> predict = irisModel.predict(testingData); //Prognostizieren Sie die Daten einzeln zur Überprüfung
String expectedResult = testingData.getOutput().getLabel(); //Richtige Antwort
String predictResult = predict.getOutput().getLabel(); //Voraussichtliches Ergebnis
if (!predictResult.equals(expectedResult)) {
System.out.println("Expected result : " + expectedResult);
System.out.println("Predicted result: " + predictResult);
System.out.println(predict.getOutputScores());
}
}
Das Ausgabeergebnis ist wie folgt.
Expected result : Iris-virginica
Predicted result: Iris-versicolor
{Iris-versicolor=(Iris-versicolor,0.5732799760841581), Iris-virginica=(Iris-virginica,0.42629863727592165), Iris-setosa=(Iris-setosa,4.213866399202189E-4)}
Die falsche Antwort "virginica" wurde mit 57,3% und die richtige Antwort "versicolor" mit 42,6% bewertet. Man kann also sagen, dass ein Fall, der nicht vorhergesagt werden konnte, bedauerlich war.
Versuchen Sie als nächstes, den oben genannten XGBoost zu verwenden. Alles was Sie tun müssen, ist den "Trainer" von "LogisticRegressionTrainer" in "XGBoostClassificationTrainer" zu ändern (obwohl Sie natürlich auch "Import" benötigen).
// Trainer<Label> trainer = new LogisticRegressionTrainer();
Trainer<Label> trainer = new XGBoostClassificationTrainer(2);
Die dem Konstruktor gegebene "2" ist die Anzahl der Entscheidungsbäume, und hier wird der Mindestwert "2" angegeben.
Das Ergebnis änderte sich nicht und die korrekte Antwortrate betrug 97,8%.
Ich werde jedoch den Build und die Methode für diejenigen schreiben, die Fehler beheben und Funktionen hinzufügen sowie verwenden möchten.
Tribuo Build erfordert Java 8 oder höher und Maven 3.5 oder höher. Überprüfen Sie dies zuerst.
$ mvn -version
Apache Maven 3.6.3
Maven home: /usr/share/maven
Java version: 1.8.0_265, vendor: Private Build, runtime: /usr/lib/jvm/java-8-openjdk-amd64/jre
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-47-generic", arch: "amd64", family: "unix"
Dieser Artikel verwendet übrigens Ubuntu 20.04 als Betriebssystem. Führen Sie zum Erstellen einfach das mvn clean package
aus.
$ mvn clean package
... (weggelassen) ...
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 03:54 min
[INFO] Finished at: 2020-09-23T16:17:34+09:00
[INFO] ------------------------------------------------------------------------
Der Build wurde in weniger als 4 Minuten abgeschlossen und es wurden ca. 350 MB Festplattenspeicher hinzugefügt. Ich habe Deeplearning4j schon einmal ausprobiert (https://qiita.com/tamura__246/items/3893ec292284c7128069), aber es ist viel leichter als das (Deeplearning4j verfügt nach einigen Stunden nach Abschluss des Builds über mehrere zehn GB freien Speicherplatz). Es war weg).
Als ich es leicht berührte und mir den Quellcode ansah, hatte ich den Eindruck, dass es "einfach für Java, aber ohne Funktionalität" sei. Um ehrlich zu sein, könnte man denken: "Mit Python können Sie verschiedene Dinge einfacher erledigen."
Es mag noch in der Zukunft liegen, aber da die Auswahlgeschwindigkeit in diesem Bereich hoch ist, ist es ein wenig zweifelhaft, ob wir in Zukunft überleben können. Ich wünschte, ich hätte einen klaren Grund, Tribuo zu verwenden ... (Es scheint, dass in der Python-Bibliothek erlernte Modelle aus Java-Programmen mit Tribuo verwendet werden können, daher kann es solche Verwendungen geben). Ich freue mich auf zukünftige Entwicklungen.
Recommended Posts