[JAVA] Ich habe Tribuo von Oracle berührt. Dokument Tribuo --Intro-Klassifizierung mit Iris

Tutorial zur Klassifizierung

In diesem Tutorial zeigen wir Ihnen, wie Sie den berühmten Iris-Datensatz von Fisher verwenden, um Irisarten mithilfe des Tribuo-Klassifizierungsmodells vorherzusagen (jetzt im Jahr 2020, aber noch in der Demo im Jahr 1936). Ich verwende den Jahresdatensatz. Seien Sie versichert, dass ich das nächste Mal den MNIST der 90er verwenden werde. Hier konzentrieren wir uns auf die einfache logistische Regression und untersuchen die Quelle und Metadaten der Daten, die Tribuo in jedem Modell speichert.

Setup Ich brauche eine Kopie des Iris-Datensatzes.

wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

Laden Sie zunächst die erforderliche Tribuo-Glasbibliothek. Hier werden das Klassifikationsexperimentglas und die json-Interop-Glasbibliothek verwendet, um Beweisinformationen zu lesen und zu schreiben.

jars ./tribuo-classification-experiments-4.0.0-jar-with-dependencies.jar
%jars ./tribuo-json-4.0.0-jar-with-dependencies.jar
import java.nio.file.Paths;

Importieren Sie alles aus dem Basispaket org.tribuo sowie einem einfachen CSV-Lade- und Klassifizierungspaket. Ich versuche eine logistische Regression aufzubauen, also brauche ich das auch.

import org.tribuo.*;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;

Diese Importe sind für das Verlaufssystem.

import com.fasterxml.jackson.databind.*;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.config.json.*;

Daten lesen In Tribuo sind alle Vorhersagetypen einer OutputFactory-Implementierung zugeordnet, mit der Sie aus der Eingabe die entsprechende Ausgabe-Unterklasse erstellen können. Da wir eine Klassifizierung für mehrere Klassen durchführen, werden wir LabelFactory verwenden. Übergeben Sie dann labelFactory an einen einfachen CSVLoader, um alle Spalten in DataSource zu laden.

var labelFactory = new LabelFactory();
var csvLoader = new CSVLoader<>(labelFactory);

Da die Kopie von Ayame (Iris) keine Spaltenüberschrift hat, erstellt sie eine Überschrift und führt sie zusammen mit dem Pfad und den auszugebenden Variablen (in diesem Fall "Spezies") der Lademethode zu. Ayame (Iris) hat keine vordefinierte Trainings- / Testaufteilung, daher werden wir 70% der Daten für das Training verwenden, um die Aufteilung zu erstellen.

var irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
var irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"),"species",irisHeaders);
var irisSplitter = new TrainTestSplitter<>(irisesSource,0.7,1L);

Füllen Sie die Trainings- und Testdatenquellen in ihre jeweiligen Datensätze ein. Diese Datensätze berechnen alle erforderlichen Metadaten, z. B. Funktionsbereiche und Ausgabebereiche. Verwenden Sie MutableDataset am besten zum Trainieren von Datensätzen. Nachdem Sie den Datensatz haben, können Sie das Modell trainieren.

var trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
var testingDataset = new MutableDataset<>(irisSplitter.getTest());
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",trainingDataset.size(),trainingDataset.getFeatureMap().size(),trainingDataset.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",testingDataset.size(),testingDataset.getFeatureMap().size(),testingDataset.getOutputInfo().size()));
Training data size = 105, number of features = 4, number of classes = 3
Testing data size = 45, number of features = 4, number of classes = 3

Training the model

Erstellen wir nun eine Instanz des Trainers und sehen uns die Standard-Hyperparameter an. Zur vollständigen Steuerung dieser Parameter können Sie den vollständig konfigurierbaren LinearSGD-Trainer direkt verwenden.

Trainer<Label> trainer = new LogisticRegressionTrainer();
System.out.println(trainer.toString());
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)

Dies ist ein lineares Modell mit logistischem Verlust, das mit AdaGrad in 5 Epochen trainiert wurde.

Lassen Sie uns nun das Modell trainieren. Wie bei jedem Paket ist das Training mit Trainingsalgorithmen und Trainingsdaten sehr einfach.

Model<Label> irisModel = trainer.train(trainingDataset);

Modellbewertung Sobald Sie ein Modell trainiert haben, müssen Sie bewerten, wie gut es trainiert ist. Fragen Sie dazu labelFactory (oder instanziieren Sie direkt), welcher Evaluator geeignet ist, und übergeben Sie das Modell und den Testdatensatz an den Evaluator. Sie können auch eine Datenquelle anstelle von dataest übergeben. Die LabelEvaluator-Klasse implementiert alle gängigen Klassifizierungsmetriken, von denen jede einzeln überprüft werden kann. LabelEvaluator.toString () erstellt eine gut formatierte Zusammenfassung der Metriken.

var evaluator = new LabelEvaluator();
var evaluation = evaluator.evaluate(irisModel,testingDataset);
System.out.println(evaluation.toString());
Class                           n          tp          fn          fp      recall        prec       f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022

Präzision, Rückruf und F1 sind Standardindikatoren, die bei der Bewertung von Klassifikatoren für mehrere Klassen verwendet werden.

Sie können auch die Verwirrungsmatrix anzeigen.

System.out.println(evaluation.getConfusionMatrix().toString());
                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa    

Modellmetadaten

Tribuo verfolgt die Funktions- und Ausgabebereiche jedes gebauten Modells. Auf diese Weise können Sie LIME-ähnliche Techniken ausführen, ohne auf die ursprünglichen Trainingsdaten zuzugreifen, oder eine Überprüfung hinzufügen, um festzustellen, ob eine bestimmte Eingabe im Trainingsmodell enthalten ist.

Werfen wir einen Blick auf die Funktionsbereiche des Iris-Modells.

var featureMap = irisModel.getFeatureIDMap();
for (var v : featureMap) {
    System.out.println(v.toString());
    System.out.println();
}
CategoricalFeature(name=petalLength,id=0,count=105,map={1.2=1, 6.9=1, 3.6=1, 3.0=1, 1.7=4, 4.9=4, 4.4=3, 3.5=2, 5.9=2, 5.4=1, 4.0=4, 1.4=12, 4.5=4, 5.0=2, 5.5=3, 6.7=2, 3.7=1, 1.9=1, 6.0=2, 5.2=1, 5.7=2, 4.2=2, 4.7=2, 4.8=4, 1.6=4, 5.8=2, 3.8=1, 6.3=1, 3.3=1, 1.0=1, 5.6=4, 5.1=5, 4.6=3, 4.1=2, 1.5=9, 1.3=4, 3.9=3, 6.6=1, 6.1=2})

CategoricalFeature(name=petalWidth,id=1,count=105,map={2.0=3, 0.5=1, 1.2=3, 0.3=6, 1.6=2, 0.1=3, 0.4=5, 2.5=3, 2.3=4, 1.7=2, 1.1=3, 2.1=4, 0.6=1, 1.4=6, 1.0=5, 2.4=1, 1.8=12, 0.2=20, 1.9=4, 1.5=7, 1.3=8, 2.2=2})

CategoricalFeature(name=sepalLength,id=2,count=105,map={6.9=3, 6.4=3, 7.4=1, 4.9=4, 4.4=1, 5.9=3, 5.4=5, 7.2=3, 7.7=3, 5.0=8, 6.2=2, 5.5=5, 6.7=7, 6.0=3, 5.2=2, 6.5=3, 5.7=4, 4.7=2, 4.8=3, 5.8=4, 5.3=1, 6.8=3, 6.3=5, 7.3=1, 5.6=6, 5.1=7, 4.6=4, 7.6=1, 7.1=1, 6.6=2, 6.1=5})

CategoricalFeature(name=sepalWidth,id=3,count=105,map={2.0=1, 2.8=10, 3.6=4, 2.3=3, 2.5=5, 3.1=8, 3.8=4, 3.0=19, 2.6=4, 4.4=1, 3.3=4, 3.5=4, 2.4=2, 3.2=10, 2.9=5, 3.7=3, 3.4=6, 2.2=2, 3.9=2, 4.2=1, 2.7=7})

Sie können ein Histogramm der vier Features und ihrer Werte sehen. Diese Informationen können verwendet werden, um Beispiele für jedes Feature zu erstellen, Kandidatenbeispiele für lokale erklärende Variablen wie LIME zu erstellen und den Umfang zu ermitteln. Funktionsinformationen werden während des Modelltrainings eingefroren. Wenn der Funktionsumfang also spärlich ist (wie dies häufig bei NLP-Problemen der Fall ist), können Sie auch feststellen, wie viele Funktionen während des Trainingssatzes aufgetreten sind. ..

Modellzertifikat

In modernen Anwendungen wurden viele verschiedene Arten von ML-Modellen eingesetzt, um verschiedene Aspekte der Anwendung zu unterstützen. Die meisten ML-Pakete unterstützen jedoch nicht die Modellverfolgung und -wiederherstellung. In Tribuo verfolgt jedes Modell seine Leistung. Sie können sehen, wie es erstellt wurde, wann es erstellt wurde und welche Daten betroffen sind. Schauen wir uns hier die tatsächlichen Ergebnisse der Irismodelldaten an. Standardmäßig zeigt Tribuo das Zertifikat in einem angemessenen, für Menschen lesbaren Format an, indem die toString () -Methode jedes Zertifikatobjekts verwendet wird. Alle Informationen sind programmgesteuert zugänglich.

var provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
TrainTestSplitter(
	class-name = org.tribuo.evaluation.TrainTestSplitter
	source = CSVLoader(
			class-name = org.tribuo.data.csv.CSVLoader
			outputFactory = LabelFactory(
					class-name = org.tribuo.classification.LabelFactory
				)
			response-name = species
			separator = ,
			quote = "
			path = file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data
			file-modified-time = 1999-12-14T15:12:39-05:00
			resource-hash = 0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC
		)
	train-proportion = 0.7
	seed = 1
	size = 150
	is-train = true
)

Sie können sehen, dass das Modell auf einer Datenquelle trainiert wird, die unter Verwendung eines bestimmten Verhältnisses von zufälligem Startwert und Aufteilung in zwei Teile geteilt wird. Die ursprüngliche Datenquelle ist eine CSV-Datei, die auch die Änderungszeit der Datei und den SHA-256-Hash aufzeichnet.

Ebenso können Sie den Trainingsalgorithmus anhand der Quelle des Auszubildenden ermitteln.

Hier können wir erwartungsgemäß sehen, dass unser Modell mit dem LogisticRegressionTrainer mit AdaGrad als Gradientenabstiegsalgorithmus trainiert wird.

Wenn Sie eine weitere Aufzeichnung führen möchten, können Sie die Erfolge aus dem Modell extrahieren und als JSON-Datei speichern (oder Sie können die Erfolge aus dem bereitgestellten Modell rückgängig machen).

ObjectMapper objMapper = new ObjectMapper();
objMapper.registerModule(new JsonProvenanceModule());
objMapper = objMapper.enable(SerializationFeature.INDENT_OUTPUT);

Obwohl die Erfolgsbilanz von json redundant ist, bietet es ein anderes für Menschen lesbares Serialisierungsformat.

System.out.println(jsonProvenance);
[ {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "linearsgdmodel-0",
  "object-class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
  "provenance-class" : "org.tribuo.provenance.ModelProvenance",
  "map" : {
    "instance-values" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance",
      "map" : { }
    },
    "tribuo-version" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "tribuo-version",
      "value" : "4.0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "trainer" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "trainer",
      "value" : "logisticregressiontrainer-2",
      "provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "trained-at" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "trained-at",
      "value" : "2020-08-31T20:24:37.854775-04:00",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "dataset" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "dataset",
      "value" : "mutabledataset-1",
      "provenance-class" : "org.tribuo.provenance.DatasetProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "mutabledataset-1",
  "object-class-name" : "org.tribuo.MutableDataset",
  "provenance-class" : "org.tribuo.provenance.DatasetProvenance",
  "map" : {
    "num-features" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-features",
      "value" : "4",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "num-examples" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-examples",
      "value" : "105",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "num-outputs" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-outputs",
      "value" : "3",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "tribuo-version" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "tribuo-version",
      "value" : "4.0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "datasource" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "datasource",
      "value" : "traintestsplitter-3",
      "provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "transformations" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance",
      "list" : [ ]
    },
    "is-sequence" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-sequence",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-dense" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-dense",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.MutableDataset",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "logisticregressiontrainer-2",
  "object-class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
  "provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
  "map" : {
    "seed" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "seed",
      "value" : "12345",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "minibatchSize" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "minibatchSize",
      "value" : "1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "train-invocation-count" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "train-invocation-count",
      "value" : "0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-sequence" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-sequence",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "shuffle" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "shuffle",
      "value" : "true",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "epochs" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "epochs",
      "value" : "5",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "optimiser" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "optimiser",
      "value" : "adagrad-4",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "Trainer",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "objective" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "objective",
      "value" : "logmulticlass-5",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "loggingInterval" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "loggingInterval",
      "value" : "1000",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "traintestsplitter-3",
  "object-class-name" : "org.tribuo.evaluation.TrainTestSplitter",
  "provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
  "map" : {
    "train-proportion" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "train-proportion",
      "value" : "0.7",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "seed" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "seed",
      "value" : "1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "size" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "size",
      "value" : "150",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "source" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "source",
      "value" : "csvloader-6",
      "provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.evaluation.TrainTestSplitter",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-train" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-train",
      "value" : "true",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "adagrad-4",
  "object-class-name" : "org.tribuo.math.optimisers.AdaGrad",
  "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
  "map" : {
    "epsilon" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "epsilon",
      "value" : "0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "initialLearningRate" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "initialLearningRate",
      "value" : "1.0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "initialValue" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "initialValue",
      "value" : "0.0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "StochasticGradientOptimiser",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.math.optimisers.AdaGrad",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "logmulticlass-5",
  "object-class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
  "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
  "map" : {
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "LabelObjective",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "csvloader-6",
  "object-class-name" : "org.tribuo.data.csv.CSVLoader",
  "provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
  "map" : {
    "resource-hash" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "resource-hash",
      "value" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance",
      "additional" : "SHA256",
      "is-reference" : false
    },
    "path" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "path",
      "value" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "file-modified-time" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "file-modified-time",
      "value" : "1999-12-14T15:12:39-05:00",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "quote" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "quote",
      "value" : "\"",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "response-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "response-name",
      "value" : "species",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "outputFactory" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "outputFactory",
      "value" : "labelfactory-7",
      "provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "separator" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "separator",
      "value" : ",",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.data.csv.CSVLoader",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "labelfactory-7",
  "object-class-name" : "org.tribuo.classification.LabelFactory",
  "provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
  "map" : {
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.LabelFactory",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
} ]

Alternativ ist das Modellzertifikat auch in der Ausgabe von Model.toString () vorhanden, dieses Format ist jedoch nicht maschinenlesbar.

linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=CSV(class-name=org.tribuo.data.csv.CSVLoader,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),response-name=species,separator=,,quote=",path=file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data,file-modified-time=1999-12-14T15:12:39-05:00,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC]),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=false,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.0.1),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),loggingInterval=1000,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2020-08-31T20:24:37.854775-04:00,instance-values={},tribuo-version=4.0.1)

Die Bewertung umfasst auch eine Erfolgsbilanz der Aufzeichnung der Modellleistung sowie der Testdatenleistung. Sie verwenden ein anderes Format für JSON-Erfolge. Dies ist jedoch etwas weniger genau. Stattdessen ist es einfacher zu lesen. Dieses Format dient als Referenz, hat jedoch alles in eine Zeichenfolge konvertiert und kann nicht zum erneuten Erstellen des ursprünglichen Leistungsobjekts verwendet werden.

String jsonEvaluationProvenance = objMapper.writeValueAsString(ProvenanceUtil.convertToMap(evaluation.getProvenance()));
System.out.println(jsonEvaluationProvenance);
{
  "tribuo-version" : "4.0.1",
  "dataset-provenance" : {
    "num-features" : "4",
    "num-examples" : "45",
    "num-outputs" : "3",
    "tribuo-version" : "4.0.1",
    "datasource" : {
      "train-proportion" : "0.7",
      "seed" : "1",
      "size" : "150",
      "source" : {
        "resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
        "path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
        "file-modified-time" : "1999-12-14T15:12:39-05:00",
        "quote" : "\"",
        "response-name" : "species",
        "outputFactory" : {
          "class-name" : "org.tribuo.classification.LabelFactory"
        },
        "separator" : ",",
        "class-name" : "org.tribuo.data.csv.CSVLoader"
      },
      "class-name" : "org.tribuo.evaluation.TrainTestSplitter",
      "is-train" : "false"
    },
    "transformations" : [ ],
    "is-sequence" : "false",
    "is-dense" : "false",
    "class-name" : "org.tribuo.MutableDataset"
  },
  "class-name" : "org.tribuo.provenance.EvaluationProvenance",
  "model-provenance" : {
    "instance-values" : { },
    "tribuo-version" : "4.0.1",
    "trainer" : {
      "seed" : "12345",
      "minibatchSize" : "1",
      "train-invocation-count" : "0",
      "is-sequence" : "false",
      "shuffle" : "true",
      "epochs" : "5",
      "optimiser" : {
        "epsilon" : "0.1",
        "initialLearningRate" : "1.0",
        "initialValue" : "0.0",
        "host-short-name" : "StochasticGradientOptimiser",
        "class-name" : "org.tribuo.math.optimisers.AdaGrad"
      },
      "host-short-name" : "Trainer",
      "class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
      "objective" : {
        "host-short-name" : "LabelObjective",
        "class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass"
      },
      "loggingInterval" : "1000"
    },
    "trained-at" : "2020-08-31T20:24:37.854775-04:00",
    "dataset" : {
      "num-features" : "4",
      "num-examples" : "105",
      "num-outputs" : "3",
      "tribuo-version" : "4.0.1",
      "datasource" : {
        "train-proportion" : "0.7",
        "seed" : "1",
        "size" : "150",
        "source" : {
          "resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
          "path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
          "file-modified-time" : "1999-12-14T15:12:39-05:00",
          "quote" : "\"",
          "response-name" : "species",
          "outputFactory" : {
            "class-name" : "org.tribuo.classification.LabelFactory"
          },
          "separator" : ",",
          "class-name" : "org.tribuo.data.csv.CSVLoader"
        },
        "class-name" : "org.tribuo.evaluation.TrainTestSplitter",
        "is-train" : "true"
      },
      "transformations" : [ ],
      "is-sequence" : "false",
      "is-dense" : "false",
      "class-name" : "org.tribuo.MutableDataset"
    },
    "class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel"
  }
}

Sie können sehen, dass diese Leistungsinformationen alle Felder enthalten, die in den Leistungsinformationen des Modells enthalten sind, sowie Testdaten, geteilte Daten und CSV.

Diese Erfolgsbilanz ist nützlich, um Modelle selbst zu verfolgen. In Kombination mit dem im Konfigurationstutorial beschriebenen Konfigurationssystem bietet sie jedoch eine leistungsstarke Möglichkeit, Modelle und Experimente sowie jedes ML-Modell zu rekonstruieren. Sie können jedoch eine nahezu perfekte Reproduzierbarkeit erzielen.

Fazit

Wir haben uns den CSV-Lademechanismus von Tribuo angesehen, wie man einen einfachen Klassifikator trainiert, wie man einen Klassifikator anhand von Testdaten bewertet und welche Metadaten und Leistungsinformationen in Tribuos Modell- und Bewertungsobjekten gespeichert sind. ..

Recommended Posts

Ich habe Tribuo von Oracle berührt. Dokument Tribuo --Intro-Klassifizierung mit Iris
Ich habe Tribuo von Oracle berührt. Document Tribuo - Eine Java-Vorhersagebibliothek (v4.0)
Ich habe Tribuo von Oracle ausprobiert. Tribuo - Eine Java-Vorhersagebibliothek (v4.0)
Ich erinnerte mich an Tribuo, das von Oracle veröffentlicht wurde. Tribuo - Eine Java-Vorhersagebibliothek (v4.0)