Évaluer la précision du modèle d'apprentissage par test croisé de scikit learn

Ce que j'ai fait

--Classez les données d'images numériques manuscrites par SVM

La source est ici.

Importation de bibliothèques et de données

Importez la bibliothèque de tests croisés "cross_validation". Les données utilisent des chiffres manuscrits.

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import svm, datasets, cross_validation
digits = datasets.load_digits()
X_digits = digits.data
y_digits = digits.target

Méthode de test croisé et réglage des paramètres

Les méthodes suivantes sont disponibles pour les tests croisés. --kFold (n, k): divise n échantillons de données en k lots. Utilisez un lot pour les tests et le lot restant (k-1) pour la formation. Modifiez l'ensemble de données utilisé pour le test et répétez le test k fois. --StratifiedKFold (y, k): diviser les données en k morceaux tout en préservant le rapport d'étiquettes dans l'ensemble de données divisé. --LeaveOneOut (n): Equivalent au cas de kFold avec k = n. Lorsque le nombre d'échantillons de données est petit. --LeaveOneLabelOut (labels): divise les données selon une étiquette donnée. Par exemple, lorsqu'il s'agit de données annuelles, lors de la réalisation de tests en les divisant en données annuelles.

Cette fois, j'utiliserai le KFold le plus simple. Le nombre de divisions était de 4. Comme je l'ai remarqué plus tard, dans KFold, si vous définissez la variable «huffle = true», il semble que l'ordre des données sera trié automatiquement.

np.random.seed(0) #Paramètre de départ aléatoire, il n'est pas nécessaire que ce soit 0
indices = np.random.permutation(len(X_digits))
X_digits = X_digits[indices] #Trier au hasard l'ordre des données
y_digits = y_digits[indices]
n_fold = 4 #Nombre de tests croisés
k_fold = cross_validation.KFold(n=len(X_digits),n_folds = n_fold)
# k_fold = cross_validation.KFold(n=len(X_digits),n_folds = n_fold, shuffle=true)
#Si tel est le cas, les quatre premières lignes ne sont pas nécessaires.

Expérimentez en changeant la variable C de SVM

Modifiez l'hyperparamètre C pour voir comment la valeur d'évaluation du modèle change. C est un paramètre qui détermine le nombre de faux positifs autorisés. Le noyau SVM était un noyau gaussien. Référence: article précédent "Reconnaître les nombres manuscrits avec SVM"

C_list = np.logspace(-8, 2, 11) # C
score = np.zeros((len(C_list),3))
tmp_train, tmp_test = list(), list()
# score_train, score_test = list(), list()
i = 0
for C in C_list:
    svc = svm.SVC(C=C, kernel='rbf', gamma=0.001)
    for train, test in k_fold:
        svc.fit(X_digits[train], y_digits[train])
        tmp_train.append(svc.score(X_digits[train],y_digits[train]))
        tmp_test.append(svc.score(X_digits[test],y_digits[test]))
        score[i,0] = C
        score[i,1] = sum(tmp_train) / len(tmp_train)
        score[i,2] = sum(tmp_test) / len(tmp_test)
        del tmp_train[:]
        del tmp_test[:]
    i = i + 1

Il est plus facile d'écrire si vous regardez simplement la valeur d'évaluation du test. Vous pouvez également spécifier le nombre de processeurs utilisés dans la variable n_jobs. -1 utilise tous les processeurs. cross_validation.cross_val_score(svc, X_digits, y_digits, cv=k_fold, n_jobs=-1) La valeur d'évaluation est sortie par tableau. array([ 0.98888889, 0.99109131, 0.99331849, 0.9844098 ])

Visualisez les résultats dans un graphique

Avec C comme axe horizontal, tracez la valeur d'évaluation pendant la formation et la valeur d'évaluation pendant le test. Si C est petit, la précision ne s'améliorera pas, probablement parce qu'un jugement erroné est trop autorisé.

xmin, xmax = score[:,0].min(), score[:,0].max()
ymin, ymax = score[:,1:2].min()-0.1, score[:,1:2].max()+0.1
plt.semilogx(score[:,0], score[:,1], c = "r", label = "train")
plt.semilogx(score[:,0], score[:,2], c = "b", label = "test")
plt.axis([xmin,xmax,ymin,ymax])
plt.legend(loc='upper left')
plt.xlabel('C')
plt.ylabel('score')
plt.show

image

Expérimentez en changeant la variable gamma SVM

Ensuite, fixez C à 100 et changez le gamma pour la même expérience. Plus le gamma est grand, plus les limites de classification sont complexes.

g_list = np.logspace(-8, 2, 11) # C
score = np.zeros((len(g_list),3))
tmp_train, tmp_test = list(), list()
i = 0
for gamma in g_list:
    svc = svm.SVC(C=100, gamma=gamma, kernel='rbf')
    for train, test in k_fold:
        svc.fit(X_digits[train], y_digits[train])
        tmp_train.append(svc.score(X_digits[train],y_digits[train]))
        tmp_test.append(svc.score(X_digits[test],y_digits[test]))
        score[i,0] = gamma
        score[i,1] = sum(tmp_train) / len(tmp_train)
        score[i,2] = sum(tmp_test) / len(tmp_test)
        del tmp_train[:]
        del tmp_test[:]
    i = i + 1

Voici les résultats. image À mesure que le gamma augmente, la précision pendant la formation et la précision pendant les tests augmentent, mais à partir d'environ 0,001, la précision pendant la formation ne change pas, mais la précision pendant les tests diminue. Il semble que le surapprentissage se produit en raison d'une trop grande complexité. Il s'avère que la définition des variables est importante.

Recommended Posts

Évaluer la précision du modèle d'apprentissage par test croisé de scikit learn
Essayez d'évaluer les performances du modèle d'apprentissage automatique / de régression
Essayez d'évaluer les performances du modèle d'apprentissage automatique / de classification
Évaluer les performances d'un modèle de régression simple à l'aide de la validation d'intersection LeaveOneOut
Notes d'apprentissage depuis le début de Python 1
Notes d'apprentissage depuis le début de Python 2
J'ai essayé d'appeler l'API de prédiction du modèle d'apprentissage automatique de WordPress
Apprenez les bases de la classification de documents par traitement du langage naturel, modèle de sujet
Optimisation des paramètres par recherche de grille de Scikit Learn
Mise en place d'un modèle de prédiction des taux de change (taux dollar-yen) par machine learning
Apprenez Nim avec Python (dès le début de l'année).
Comptez le nombre de paramètres dans le modèle d'apprentissage en profondeur
Othello ~ De la troisième ligne de "Implementation Deep Learning" (4) [Fin]
À en juger par l'image du chien Shiba en apprenant en profondeur si c'est mon enfant (1)
Un exemple de mécanisme qui renvoie une prédiction par HTTP à partir du résultat de l'apprentissage automatique
[Français] scikit-learn 0.18 Guide de l'utilisateur 3.1. Validation croisée: évaluer les performances de l'estimateur
Prédire la présence ou l'absence d'infidélité par l'apprentissage automatique
L'histoire du champ de modèle Django disparaissant de la classe
Renforcer l'apprentissage Apprendre d'aujourd'hui
Essayez de prédire le triplet de la course de bateaux en classant l'apprentissage
Déterminer s'il s'agit de mon enfant à partir de l'image du chien Shiba par apprentissage profond (3) Visualisation par Grad-CAM
Existence du point de vue de Python
Simulation Python du modèle épidémique (modèle Kermack-McKendrick)
Apprenez du concours code-Mercari gagnant
Validez le modèle d'entraînement avec Pylearn2
Apprenez les bases de Python ① Débutants élémentaires
À propos de l'ordre d'apprentissage des langages de programmation (de débutant à intermédiaire) Partie 2
Graphique de l'historique du nombre de couches de deep learning et du changement de précision
Traitement de la voix par apprentissage profond: identifions qui est l'acteur vocal à partir de la voix
Mémo d'apprentissage Python pour l'apprentissage automatique par Chainer jusqu'à la fin du chapitre 2
Déterminez l'authenticité des articles publiés par machine learning (API Google Prediction).