Manuel de science des données Pythonの勉強中に思ったこと。
J'ai fait une carte thermique avec seaborn pour visualiser les erreurs de classification, mais n'est-ce pas la couleur avec le plus grand nombre dans l'ensemble? (Bien que le nombre d'échantillons devrait être le même pour toutes les catégories, des déséquilibres de données se produisent souvent)
Il serait préférable d'avoir une carte thermique qui montre le rapport de chaque élément au nombre total de lignes (= chaque classification). Je l'ai fait.
load_and_modelfitting.py
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
#À titre d'exemple, chargez cette fois l'image des caractères manuscrits en tant que tâche de classification
digits = load_digits()
X = digits.data
y = digits.target
#Divisé pour la formation et l'évaluation
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)
#Appliquer de manière appropriée les bayes naïfs gaussiennes à l'algorithme de classification
model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)
accuracy_score(ytest, y_model)
Heatmap la matrice de confusion habituelle et le tableau des ratios aux lignes
create_confmrx.py
#Tableau bidimensionnel de matrice de confusion ordinaire
mat = confusion_matrix(ytest, y_model)
#Un tableau à deux dimensions qui calcule le rapport au total de chaque ligne et arrondit le troisième chiffre.
mat_dec = np.round(mat / np.sum(mat, axis=1), decimals=2)
fig, axes = plt.subplots(1, 2, figsize=(10, 10))
kwargs = dict(square=True, annot=True, cbar=False, cmap='RdPu')
#Dessinez deux cartes thermiques
for i, dat in enumerate([mat, mat_dec]):
sns.heatmap(dat, **kwargs, ax=axes[i])
#Définir le titre du graphique, les étiquettes des axes X et Y
for ax, t in zip(axes, ['Real number', 'Percentage(per row)']):
plt.axes(ax)
plt.title(t)
plt.xlabel('predicted value')
plt.ylabel('true value')
La matrice de confusion est difficile à voir ...
Recommended Posts