Confusion matrices are easy to visualize with scikit-learn, but `sklearn.metrics.plot_confusion_matrix` **requires a fitted estimator as an argument**. While looking for a function that only visualizes and does not need an estimator, I found `sklearn.metrics.ConfusionMatrixDisplay`, which takes a precomputed confusion matrix. With it, the code turned out to be simple:
```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Load the data and hold out a test set
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Train a classifier and predict on the test set
clf = SVC(random_state=0)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Compute the matrix (rows: true labels, columns: predicted labels)
cm = confusion_matrix(y_pred=y_pred, y_true=y_test)

# ConfusionMatrixDisplay only needs the matrix, not the estimator
cmp = ConfusionMatrixDisplay(cm, display_labels=data.target_names)
cmp.plot(cmap=plt.cm.Blues)
plt.show()
```
The result looks like this.