[JAVA] J'ai touché Tribuo publié par Oracle. Document Tribuo - Introduction à la classification avec les iris

Tutoriel de classification

Dans ce didacticiel, nous vous montrerons comment utiliser le célèbre jeu de données d'iris de Fisher pour prédire les espèces d'iris à l'aide du modèle de classification de Tribuo (maintenant 2020, mais la démo est toujours 1936). J'utilise l'ensemble de données de l'année. Soyez assuré que la prochaine fois, j'utiliserai le MNIST des années 90). Ici, nous nous concentrons sur la régression logistique simple et étudions la source et les métadonnées des données que Tribuo stocke dans chaque modèle.

`` Configuration '' Je dois obtenir une copie de l'ensemble de données iris.

wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

Commencez par charger la bibliothèque JAR Tribuo requise. Ici, l'expérience de classification jar et la bibliothèque json interop jar sont utilisées pour lire et écrire les informations de preuve.

jars ./tribuo-classification-experiments-4.0.0-jar-with-dependencies.jar
%jars ./tribuo-json-4.0.0-jar-with-dependencies.jar
import java.nio.file.Paths;

Importez tout à partir du package org.tribuo de base, ainsi que d'un simple chargeur CSV et d'un package de classification. J'essaye de construire une régression logistique, donc j'en ai besoin aussi.

import org.tribuo.*;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;

Ces importations sont destinées au système d'historique.

import com.fasterxml.jackson.databind.*;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.config.json.*;

`` Lire des données '' Dans Tribuo, tous les types de prédiction sont associés à une implémentation OutputFactory qui vous permet de créer la sous-classe Output appropriée à partir de l'entrée. Puisque nous effectuons une classification multi-classes, nous utiliserons LabelFactory. Passez ensuite labelFactory à un simple CSVLoader pour charger toutes les colonnes dans DataSource.

var labelFactory = new LabelFactory();
var csvLoader = new CSVLoader<>(labelFactory);

Étant donné que la copie d'Ayame (iris) n'a pas d'en-tête de colonne, créez un en-tête et insérez-le dans la méthode de chargement avec le chemin et les variables à afficher (dans ce cas "espèce"). Ayame (iris) n'a pas de fractionnement d'entraînement / test prédéfini, nous allons donc utiliser 70% des données pour l'entraînement pour créer le fractionnement.

var irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
var irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"),"species",irisHeaders);
var irisSplitter = new TrainTestSplitter<>(irisesSource,0.7,1L);

Remplissez les sources de données d'entraînement et de test dans leurs ensembles de données respectifs. Ces jeux de données calculent toutes les métadonnées requises, telles que les zones d'entités et les zones de sortie. Il est préférable d'utiliser MutableDataset pour les ensembles de données d'entraînement. Maintenant que vous disposez du jeu de données, vous êtes prêt à entraîner le modèle.

var trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
var testingDataset = new MutableDataset<>(irisSplitter.getTest());
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",trainingDataset.size(),trainingDataset.getFeatureMap().size(),trainingDataset.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",testingDataset.size(),testingDataset.getFeatureMap().size(),testingDataset.getOutputInfo().size()));
Training data size = 105, number of features = 4, number of classes = 3
Testing data size = 45, number of features = 4, number of classes = 3

Training the model

Créons maintenant une instance du formateur et examinons les hyperparamètres par défaut. Pour un contrôle total de ces paramètres, vous pouvez utiliser directement le LinearSGD Trainer entièrement configurable.

Trainer<Label> trainer = new LogisticRegressionTrainer();
System.out.println(trainer.toString());
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)

Il s'agit d'un modèle linéaire avec perte logistique, formé avec AdaGrad en 5 époques.

Entraînons maintenant le modèle. Comme avec n'importe quel package, l'entraînement est très simple avec des algorithmes d'entraînement et des données d'entraînement.

Model<Label> irisModel = trainer.train(trainingDataset);

`` Évaluation du modèle '' Une fois que vous avez formé un modèle, vous devez évaluer dans quelle mesure il est formé. Pour ce faire, demandez à labelFactory (ou instanciez directement) quel est l'évaluateur approprié et transmettez le modèle et l'ensemble de données de test à l'évaluateur. Vous pouvez également transmettre une source de données au lieu de dataest. La classe LabelEvaluator implémente toutes les métriques de classification courantes, chacune pouvant être inspectée individuellement. LabelEvaluator.toString () produit un résumé joliment formaté des métriques.

var evaluator = new LabelEvaluator();
var evaluation = evaluator.evaluate(irisModel,testingDataset);
System.out.println(evaluation.toString());
Class                           n          tp          fn          fp      recall        prec       f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022

précision, rappel et F1 sont des indicateurs standard utilisés lors de l'évaluation des classificateurs multiclasses.

Vous pouvez également afficher la matrice de confusion.

System.out.println(evaluation.getConfusionMatrix().toString());
                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa    

`` Métadonnées du modèle ''

Tribuo garde une trace des fonctionnalités et des zones de sortie de chaque modèle construit. Cela vous permet d'exécuter des techniques similaires à LIME sans accéder aux données d'entraînement d'origine, ou d'ajouter une vérification pour voir si une entrée particulière se trouve dans le modèle d'entraînement.

Jetons un coup d'œil aux fonctionnalités du modèle Iris.

var featureMap = irisModel.getFeatureIDMap();
for (var v : featureMap) {
    System.out.println(v.toString());
    System.out.println();
}
CategoricalFeature(name=petalLength,id=0,count=105,map={1.2=1, 6.9=1, 3.6=1, 3.0=1, 1.7=4, 4.9=4, 4.4=3, 3.5=2, 5.9=2, 5.4=1, 4.0=4, 1.4=12, 4.5=4, 5.0=2, 5.5=3, 6.7=2, 3.7=1, 1.9=1, 6.0=2, 5.2=1, 5.7=2, 4.2=2, 4.7=2, 4.8=4, 1.6=4, 5.8=2, 3.8=1, 6.3=1, 3.3=1, 1.0=1, 5.6=4, 5.1=5, 4.6=3, 4.1=2, 1.5=9, 1.3=4, 3.9=3, 6.6=1, 6.1=2})

CategoricalFeature(name=petalWidth,id=1,count=105,map={2.0=3, 0.5=1, 1.2=3, 0.3=6, 1.6=2, 0.1=3, 0.4=5, 2.5=3, 2.3=4, 1.7=2, 1.1=3, 2.1=4, 0.6=1, 1.4=6, 1.0=5, 2.4=1, 1.8=12, 0.2=20, 1.9=4, 1.5=7, 1.3=8, 2.2=2})

CategoricalFeature(name=sepalLength,id=2,count=105,map={6.9=3, 6.4=3, 7.4=1, 4.9=4, 4.4=1, 5.9=3, 5.4=5, 7.2=3, 7.7=3, 5.0=8, 6.2=2, 5.5=5, 6.7=7, 6.0=3, 5.2=2, 6.5=3, 5.7=4, 4.7=2, 4.8=3, 5.8=4, 5.3=1, 6.8=3, 6.3=5, 7.3=1, 5.6=6, 5.1=7, 4.6=4, 7.6=1, 7.1=1, 6.6=2, 6.1=5})

CategoricalFeature(name=sepalWidth,id=3,count=105,map={2.0=1, 2.8=10, 3.6=4, 2.3=3, 2.5=5, 3.1=8, 3.8=4, 3.0=19, 2.6=4, 4.4=1, 3.3=4, 3.5=4, 2.4=2, 3.2=10, 2.9=5, 3.7=3, 3.4=6, 2.2=2, 3.9=2, 4.2=1, 2.7=7})

Vous pouvez voir un histogramme des quatre caractéristiques et leurs valeurs. Ces informations peuvent être utilisées pour échantillonner chaque fonctionnalité, créer des exemples candidats de variables explicatives locales telles que LIME et vérifier la plage. Les informations sur les fonctionnalités sont figées pendant l'apprentissage du modèle, donc si l'ensemble de fonctionnalités est clairsemé (comme c'est souvent le cas avec les problèmes de PNL), il peut également être utilisé pour voir combien de fonctionnalités se sont produites au cours de l'ensemble d'apprentissage. ..

`` Modèle de certificat ''

De nombreux types de modèles ML ont été déployés dans des applications modernes pour prendre en charge divers aspects de l'application. Cependant, la plupart des packages ML ne prennent pas en charge le suivi et la reconstruction de modèle. Dans Tribuo, chaque modèle suit ses performances. Vous pouvez voir comment il a été créé, quand il a été créé et quelles données sont impliquées. Ici, jetons un coup d'œil aux résultats réels des données du modèle d'iris. Par défaut, Tribuo affiche le certificat dans un format raisonnable et lisible par l'homme en utilisant la méthode toString () de chaque objet de certificat. Toutes les informations sont accessibles par programme.

var provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
TrainTestSplitter(
	class-name = org.tribuo.evaluation.TrainTestSplitter
	source = CSVLoader(
			class-name = org.tribuo.data.csv.CSVLoader
			outputFactory = LabelFactory(
					class-name = org.tribuo.classification.LabelFactory
				)
			response-name = species
			separator = ,
			quote = "
			path = file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data
			file-modified-time = 1999-12-14T15:12:39-05:00
			resource-hash = 0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC
		)
	train-proportion = 0.7
	seed = 1
	size = 150
	is-train = true
)

Vous pouvez voir que le modèle est entraîné sur une source de données qui est divisée en deux, en utilisant un rapport de départ et de division aléatoire spécifique. La source de données d'origine est un fichier CSV, qui enregistre également l'heure de modification du fichier et le hachage SHA-256.

De même, vous pouvez découvrir l'algorithme de formation en regardant la source du stagiaire.

Ici, comme prévu, nous pouvons voir que notre modèle est formé en utilisant le LogisticRegressionTrainer avec AdaGrad comme algorithme de descente de gradient.

Si vous souhaitez conserver un autre enregistrement, vous pouvez extraire les réalisations du modèle et les enregistrer sous forme de fichier json (ou vous pouvez annuler les réalisations du modèle déployé).

ObjectMapper objMapper = new ObjectMapper();
objMapper.registerModule(new JsonProvenanceModule());
objMapper = objMapper.enable(SerializationFeature.INDENT_OUTPUT);

Bien que l'historique de json soit redondant, il offre un autre format de sérialisation lisible par l'homme.

System.out.println(jsonProvenance);
[ {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "linearsgdmodel-0",
  "object-class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
  "provenance-class" : "org.tribuo.provenance.ModelProvenance",
  "map" : {
    "instance-values" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance",
      "map" : { }
    },
    "tribuo-version" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "tribuo-version",
      "value" : "4.0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "trainer" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "trainer",
      "value" : "logisticregressiontrainer-2",
      "provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "trained-at" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "trained-at",
      "value" : "2020-08-31T20:24:37.854775-04:00",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "dataset" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "dataset",
      "value" : "mutabledataset-1",
      "provenance-class" : "org.tribuo.provenance.DatasetProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "mutabledataset-1",
  "object-class-name" : "org.tribuo.MutableDataset",
  "provenance-class" : "org.tribuo.provenance.DatasetProvenance",
  "map" : {
    "num-features" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-features",
      "value" : "4",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "num-examples" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-examples",
      "value" : "105",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "num-outputs" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-outputs",
      "value" : "3",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "tribuo-version" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "tribuo-version",
      "value" : "4.0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "datasource" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "datasource",
      "value" : "traintestsplitter-3",
      "provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "transformations" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance",
      "list" : [ ]
    },
    "is-sequence" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-sequence",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-dense" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-dense",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.MutableDataset",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "logisticregressiontrainer-2",
  "object-class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
  "provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
  "map" : {
    "seed" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "seed",
      "value" : "12345",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "minibatchSize" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "minibatchSize",
      "value" : "1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "train-invocation-count" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "train-invocation-count",
      "value" : "0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-sequence" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-sequence",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "shuffle" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "shuffle",
      "value" : "true",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "epochs" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "epochs",
      "value" : "5",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "optimiser" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "optimiser",
      "value" : "adagrad-4",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "Trainer",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "objective" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "objective",
      "value" : "logmulticlass-5",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "loggingInterval" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "loggingInterval",
      "value" : "1000",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "traintestsplitter-3",
  "object-class-name" : "org.tribuo.evaluation.TrainTestSplitter",
  "provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
  "map" : {
    "train-proportion" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "train-proportion",
      "value" : "0.7",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "seed" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "seed",
      "value" : "1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "size" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "size",
      "value" : "150",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "source" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "source",
      "value" : "csvloader-6",
      "provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.evaluation.TrainTestSplitter",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-train" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-train",
      "value" : "true",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "adagrad-4",
  "object-class-name" : "org.tribuo.math.optimisers.AdaGrad",
  "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
  "map" : {
    "epsilon" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "epsilon",
      "value" : "0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "initialLearningRate" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "initialLearningRate",
      "value" : "1.0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "initialValue" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "initialValue",
      "value" : "0.0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "StochasticGradientOptimiser",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.math.optimisers.AdaGrad",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "logmulticlass-5",
  "object-class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
  "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
  "map" : {
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "LabelObjective",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "csvloader-6",
  "object-class-name" : "org.tribuo.data.csv.CSVLoader",
  "provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
  "map" : {
    "resource-hash" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "resource-hash",
      "value" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance",
      "additional" : "SHA256",
      "is-reference" : false
    },
    "path" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "path",
      "value" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "file-modified-time" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "file-modified-time",
      "value" : "1999-12-14T15:12:39-05:00",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "quote" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "quote",
      "value" : "\"",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "response-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "response-name",
      "value" : "species",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "outputFactory" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "outputFactory",
      "value" : "labelfactory-7",
      "provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "separator" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "separator",
      "value" : ",",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.data.csv.CSVLoader",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "labelfactory-7",
  "object-class-name" : "org.tribuo.classification.LabelFactory",
  "provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
  "map" : {
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.LabelFactory",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
} ]

Sinon, le modèle de certificat est également présent dans la sortie de Model.toString (), mais ce format n'est pas lisible par machine.

linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=CSV(class-name=org.tribuo.data.csv.CSVLoader,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),response-name=species,separator=,,quote=",path=file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data,file-modified-time=1999-12-14T15:12:39-05:00,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC]),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=false,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.0.1),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),loggingInterval=1000,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2020-08-31T20:24:37.854775-04:00,instance-values={},tribuo-version=4.0.1)

L'évaluation comprend également un historique de l'enregistrement des performances du modèle ainsi que des performances des données de test. Vous utilisez un autre format de succès JSON. Cependant, c'est un peu moins précis. Au lieu de cela, c'est plus facile à lire. Ce format est bon pour référence, mais il a tout converti en une chaîne et ne peut pas être utilisé pour reconstruire l'objet de réussite d'origine.

String jsonEvaluationProvenance = objMapper.writeValueAsString(ProvenanceUtil.convertToMap(evaluation.getProvenance()));
System.out.println(jsonEvaluationProvenance);
{
  "tribuo-version" : "4.0.1",
  "dataset-provenance" : {
    "num-features" : "4",
    "num-examples" : "45",
    "num-outputs" : "3",
    "tribuo-version" : "4.0.1",
    "datasource" : {
      "train-proportion" : "0.7",
      "seed" : "1",
      "size" : "150",
      "source" : {
        "resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
        "path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
        "file-modified-time" : "1999-12-14T15:12:39-05:00",
        "quote" : "\"",
        "response-name" : "species",
        "outputFactory" : {
          "class-name" : "org.tribuo.classification.LabelFactory"
        },
        "separator" : ",",
        "class-name" : "org.tribuo.data.csv.CSVLoader"
      },
      "class-name" : "org.tribuo.evaluation.TrainTestSplitter",
      "is-train" : "false"
    },
    "transformations" : [ ],
    "is-sequence" : "false",
    "is-dense" : "false",
    "class-name" : "org.tribuo.MutableDataset"
  },
  "class-name" : "org.tribuo.provenance.EvaluationProvenance",
  "model-provenance" : {
    "instance-values" : { },
    "tribuo-version" : "4.0.1",
    "trainer" : {
      "seed" : "12345",
      "minibatchSize" : "1",
      "train-invocation-count" : "0",
      "is-sequence" : "false",
      "shuffle" : "true",
      "epochs" : "5",
      "optimiser" : {
        "epsilon" : "0.1",
        "initialLearningRate" : "1.0",
        "initialValue" : "0.0",
        "host-short-name" : "StochasticGradientOptimiser",
        "class-name" : "org.tribuo.math.optimisers.AdaGrad"
      },
      "host-short-name" : "Trainer",
      "class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
      "objective" : {
        "host-short-name" : "LabelObjective",
        "class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass"
      },
      "loggingInterval" : "1000"
    },
    "trained-at" : "2020-08-31T20:24:37.854775-04:00",
    "dataset" : {
      "num-features" : "4",
      "num-examples" : "105",
      "num-outputs" : "3",
      "tribuo-version" : "4.0.1",
      "datasource" : {
        "train-proportion" : "0.7",
        "seed" : "1",
        "size" : "150",
        "source" : {
          "resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
          "path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
          "file-modified-time" : "1999-12-14T15:12:39-05:00",
          "quote" : "\"",
          "response-name" : "species",
          "outputFactory" : {
            "class-name" : "org.tribuo.classification.LabelFactory"
          },
          "separator" : ",",
          "class-name" : "org.tribuo.data.csv.CSVLoader"
        },
        "class-name" : "org.tribuo.evaluation.TrainTestSplitter",
        "is-train" : "true"
      },
      "transformations" : [ ],
      "is-sequence" : "false",
      "is-dense" : "false",
      "class-name" : "org.tribuo.MutableDataset"
    },
    "class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel"
  }
}

Vous pouvez voir que ces informations de performances incluent tous les champs contenus dans les informations de performances du modèle, ainsi que les données de test, les données fractionnées et CSV.

Cet historique est utile pour suivre les modèles seuls, mais lorsqu'il est combiné avec le système de configuration décrit dans le didacticiel de configuration, il fournit un moyen puissant de reconstruire des modèles et des expériences, ainsi que tout modèle ML. Mais vous pouvez obtenir une reproductibilité presque parfaite.

`` Conclusion ''

Nous avons examiné le mécanisme de chargement csv de Tribuo, comment former un classificateur simple, comment évaluer un classificateur sur des données de test, ainsi que les métadonnées et les informations de performance stockées dans le modèle et les objets d'évaluation de Tribuo. ..

Recommended Posts

J'ai touché Tribuo publié par Oracle. Document Tribuo - Introduction à la classification avec les iris
J'ai touché Tribuo publié par Oracle. Document Tribuo --Une bibliothèque de prédiction Java (v4.0)
J'ai essayé Tribuo édité par Oracle. Tribuo --Une bibliothèque de prédiction Java (v4.0)
Je me suis rappelé Tribuo publié par Oracle. Tribuo --Une bibliothèque de prédiction Java (v4.0)