Le concours kaggle [Google Cloud & NCAA® ML Competition 2020-NCAAW](https://www.kaggle.com/c/google-cloud-ncaa-march-madness-2020-division-] auquel j'ai participé en mars 2020. En raison de l'introduction de la fonction de suivi de mlflow dans 1-womens-tournoi), il était facile à utiliser, je vais donc le publier sous forme de mémorandum. Le contenu de la description décrit principalement comment introduire la fonction de tracking de mlflow et les points sur lesquels je suis tombé par hasard lors de son introduction.
mlflow est une plate-forme open source qui gère le cycle de vie du machine learning (pré-traitement-> apprentissage-> déploiement), et a trois fonctions principales. --Suivi: journalisation --Projets: Emballage --Modèles: support de déploiement Cette fois, je vais principalement aborder la manière d'introduire le suivi. Pour plus de détails sur les projets et les modèles, veuillez consulter ici.
Le suivi est une fonction qui enregistre chaque paramètre, index d'évaluation et résultat, fichier de sortie, etc. lors de la création d'un modèle d'apprentissage automatique. De plus, si vous mettez un projet dans git, vous pouvez gérer la version du code, mais je pensais que l'histoire s'étendrait aux projets lorsqu'il s'agira de l'introduire, donc je l'omettrai cette fois (la prochaine fois, je parlerai de projets) Je veux gérer ça quand je le fais).
mlflow peut être installé avec pip.
pip install mlflow
Définissez l'URI pour la journalisation (par défaut, il est créé directement sous le dossier lors de l'exécution).
Non seulement le répertoire local, mais également la base de données et le serveur HTTP peuvent être spécifiés pour l'URI.
Le nom du répertoire de destination de la journalisation doit être mlruns
(la raison sera expliquée plus tard).
import mlflow
mlflow.set_tracking_uri('./hoge/mlruns/')
L'expérience est créée arbitrairement par l'analyste pour chaque tâche du projet d'apprentissage automatique (par exemple, quantité de fonctionnalités, méthode d'apprentissage automatique, comparaison de paramètres, etc.).
#Si l'expérience n'existe pas, elle sera créée.
mlflow.set_experiment('compare_max_depth')
Enregistrons réellement.
with mlflow.start_run():
mlflow.log_param('param1', 1) #Paramètres
mlflow.log_metric('metric1', 0.1) #But
mlflow.log_artifact('./model.pickle') #Autres modèles, données, etc.
mlflow.search_runs() #Vous pouvez obtenir le contenu de la journalisation dans l'expérience
Il enregistre les paramètres, les scores, les modèles, etc. Veuillez vous référer au Document officiel pour les spécifications détaillées de chaque fonction.
Accédez au répertoire défini par URI. À ce stade, assurez-vous que le répertoire mlruns
est sous le contrôle (si le répertoire mlruns
n'existe pas, le répertoire mlruns
sera créé).
Démarrez le serveur local avec mlflow ui
.
$ cd ./hoge/
$ ls
mlruns
$ mlflow ui
Lorsque vous ouvrez http: //127.0.0.1: 5000
sur le navigateur, l'écran suivant s'affiche.
Il est également possible de comparer chaque paramètre.
Tips
tracking = mlflow.tracking.MlflowClient()
experiment = tracking.get_experiment_by_name('hoge')
print(experiment.experiment_id)
#Méthode 1:Obtenir la liste des expériences
tracking.list_experiments()
#Méthode 2:
tracking = mlflow.tracking.MlflowClient()
experimet = tracking.get_experiment('1') #passer l'identifiant de l'expérience
print(experimet.name)
tracking = mlflow.tracking.MlflowClient()
tracking.delete_experiment('1')
with mlflow.start_run():
run_id = mlflow.active_run().info.run_id
Si vous passez le run_id acquis au paramètre de start_run ()
, le journal de la cible run_id sera écrasé.
tracking = mlflow.tracking.MlflowClient()
tracking.delete_run(run_id)
#Si vous souhaitez enregistrer plusieurs paramètres en même temps, transmettez-le avec dict.
params = {
'test1': 1,
'test2': 2
}
metrics = {
'metric1': 0.1,
'metric2': 0.2
}
with mlflow.start_run():
mlflow.log_params(params)
mlflow.log_metrics(metrics)
download artifacts
tracking = mlflow.tracking.MlflowClient()
print(tracking.list_artifacts(run_id=run_id)) #Obtenez une liste d'artefacts
[<FileInfo: file_size=23, is_dir=False, path='model.pickle'>]
tracking.download_artifacts(run_id=run_id, path='model.pickle', dst_path='./')