How to visualize the decision tree model of scikit-learn

Overview

This entry explains how to visualize a decision tree model in scikit-learn. I have been using this a lot lately, so I am writing it down as a memo and personal cheat sheet. The sample code in this entry was built with the Windows version of Python 3.5.2.

Preparing the environment

The components required to visualize the decision tree are:

- Graphviz
- scikit-learn
- pydotplus

Graphviz is installed differently on each OS. scikit-learn is often already present (the Anaconda distribution used in this entry includes it). pydotplus, on the other hand, needs to be installed with pip.
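Before installing anything, it can help to confirm which of the Python packages are already importable. The following is a minimal sketch that uses only the standard library. The helper name missing_packages is made up for this illustration, and "sklearn" and "pydotplus" are assumed to be the import names of scikit-learn and pydotplus.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names from `names` that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # "sklearn" is the import name of scikit-learn.
    missing = missing_packages(["sklearn", "pydotplus"])
    if missing:
        print("Not installed:", ", ".join(missing))
    else:
        print("scikit-learn and pydotplus are both importable.")
```

This only checks the Python side. Graphviz itself is a separate program and is not covered by this check.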

Graphviz installation (Windows 7 environment)

Graphviz stands for Graph Visualization Software. It is a tool that renders graphs described in the DOT language as images. See the documentation for details: http://www.graphviz.org/Documentation.php

The download pages are as follows.

- Windows version: http://www.graphviz.org/Download_windows.php
- RHEL, CentOS version: http://www.graphviz.org/Download_linux_rhel.php
- Ubuntu version: http://www.graphviz.org/Download_linux_ubuntu.php
- Source version: http://www.graphviz.org/Download_source.php

Since this entry installs into a Windows environment, download the MSI file from the Windows page above and run it.

When you run the downloaded MSI file, the welcome screen is displayed first. Click Next to continue. (install_01.png)

The version at the time of writing (2017/09/03) is 2.38. Select "Everyone" so that all users can use it, then proceed. (install_02.png)

This screen tells you that the installation is ready. Click Next to proceed. (install_03.png)

As the components are installed, the progress bar fills up. Click Next once it is completely full. (install_04.png)

Graphviz is now installed. Click Close to close the window. (install_05.png)

Next, move on to installing pydotplus.

pydotplus

pydotplus is a Python module for working with the DOT language mentioned earlier. Since this is a Windows environment, we will work in Anaconda Prompt. (launch_anaconda_prompt.png)

After launching Anaconda Prompt, run the command "pip install pydotplus".

(C:\Program Files\Anaconda3) C:\Users\usr********>pip install pydotplus
Collecting pydotplus
  Downloading pydotplus-2.0.2.tar.gz (278kB)
    100% |################################| 286kB 860kB/s
Requirement already satisfied: pyparsing>=2.0.1 in c:\program files\anaconda3\li
b\site-packages (from pydotplus)
Building wheels for collected packages: pydotplus
  Running setup.py bdist_wheel for pydotplus ... done
  Stored in directory: C:\Users\usr********\AppData\Local\pip\Cache\wheels\43\31\
48\e1d60511537b50a8ec28b130566d2fbbe4ac302b0def4baa48
Successfully built pydotplus
Installing collected packages: pydotplus
Successfully installed pydotplus-2.0.2

If the installation succeeds, output like the above is displayed. If it finishes without errors or interruption, pydotplus is installed.

Setting environment variables

Next, edit the environment variables so that pydotplus can find the Graphviz installation. First, locate the "bin" directory under the directory where Graphviz is installed. You can check it in the properties of "gvedit.exe" in the Start menu. (install_path.png)

Add this path ("C:\Program Files (x86)\Graphviz2.38\bin") to the PATH environment variable. (setup_env.png)

After changing the path, restart your Python IDE (such as PyCharm) so that it picks up the new value.
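To confirm that the PATH change took effect, you can ask Python whether it can locate Graphviz's dot command. This is a minimal sketch using the standard library's shutil.which. The function name find_dot_executable is made up for this illustration, and it assumes that the Graphviz "bin" directory contains an executable named dot.

```python
import shutil


def find_dot_executable():
    """Return the full path of Graphviz's `dot` command, or None if it is not on PATH."""
    return shutil.which("dot")


if __name__ == "__main__":
    dot_path = find_dot_executable()
    if dot_path is None:
        print("dot was not found; check the PATH setting and restart the IDE.")
    else:
        print("dot found at:", dot_path)
```

If this reports that dot was not found from inside the IDE, the IDE was probably started before the environment variable was changed.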

Sample code for visualization of decision tree model

I built a RandomForest model on the familiar iris dataset, took one decision tree out of that model, and visualized it (that is, wrote it out as a PNG image). The sample code is as follows. The debugging print statements that help show the internal processing are left in.

u"""
Visualize the decision tree model.
Visualize a model of a decision tree using Graphviz.
It can be applied not only to decision trees but also to tree-structured models such as random forests.

"""

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
#train_test_split lives in sklearn.model_selection (the old sklearn.cross_validation module has been removed)
from sklearn.model_selection import train_test_split

#Packages needed to visualize the tree structure of the model
from sklearn import tree
import pydotplus as pdp

import pandas as pd
import numpy as np

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))

#Separate training data and test data

features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)

clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)
print("========================================================")
print("Prediction accuracy")
print(clf.score(df_test, label_test))

#Visualize one of the trees to try
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0], #One decision tree object taken from the forest
                                out_file=None, #None returns the DOT data as a string instead of writing it to a file
                                filled=True, #When True, each node is colored according to its majority class
                                rounded=True, #When True, node boxes are drawn with rounded corners
                                feature_names=list(features), #If this is not specified, the feature names are not displayed on the chart (passed as a plain list)
                                class_names=list(iris.target_names), #If this is not specified, the class names are not displayed on the chart (passed as a plain list)
                                special_characters=True #Allow special characters in the output
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)

Let's look at each part. The following part loads the iris dataset and prepares the training data. The feature names are in iris.feature_names, and the objective variable (the iris species) is in iris.target. However, iris.target holds numeric codes, which are hard for humans to read. Therefore, using the species names in iris.target_names, we store a human-readable objective variable in df['species'].

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))
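To make the code-to-name conversion concrete, here is a small sketch of pd.Categorical.from_codes on toy data. The lists codes and names are made-up stand-ins for iris.target and iris.target_names, not the real dataset.

```python
import pandas as pd

# Toy stand-ins for iris.target (numeric codes) and iris.target_names (labels).
codes = [0, 2, 1, 0]
names = ["setosa", "versicolor", "virginica"]

# Each code is used as an index into `names`.
species = pd.Categorical.from_codes(codes, names)
print(list(species))  # ['setosa', 'virginica', 'versicolor', 'setosa']
```

Each code simply selects the name at that position, so code 0 becomes "setosa" and code 2 becomes "virginica".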

Next is the code that creates the Random Forest model. The df that holds the training data also contains the objective variable, so the features and the objective variable have to be separated before they are passed to the model. The feature column names are stored in features and the objective variable in label. train_test_split then divides the data into training data and test data. clf holds the RandomForest object, which uses 150 decision trees (argument: n_estimators=150). Finally, the fit() method trains the model on the training data.

features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)

clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)
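Because both train_test_split and RandomForestClassifier involve randomness, the accuracy and the shape of each tree change from run to run. If repeatable results are wanted, one option is to fix random_state. The following is a sketch under that assumption: random_state=0 is an arbitrary illustrative value that does not appear in the sample code above, and the raw arrays iris.data and iris.target are used instead of the DataFrame to keep the example short.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

iris = load_iris()

# random_state is an illustrative choice; any fixed integer gives a repeatable split.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0
)

clf = RandomForestClassifier(n_estimators=150, random_state=0)
clf.fit(X_train, y_train)

# The fitted forest holds one decision tree per requested estimator.
print(len(clf.estimators_))
print(type(clf.estimators_[0]).__name__)
```

The two print calls show that the forest contains as many trees as n_estimators and that each element is an ordinary decision tree classifier.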

And finally, the visualization. The fitted RandomForest object has an attribute called estimators_, which is a list of decision tree objects. Here we visualize the first decision tree object (estimators[0]) as a sample and write it out as the PNG image file "tree_visualization.png". tree.export_graphviz() generates the DOT description of the tree, and pydotplus passes it to Graphviz to render the image. The arguments are explained in the comments in the code. Note that if feature_names and class_names are not specified, neither the feature names nor the class names are displayed.

#Visualize one of the trees to try
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0], #One decision tree object taken from the forest
                                out_file=None, #None returns the DOT data as a string instead of writing it to a file
                                filled=True, #When True, each node is colored according to its majority class
                                rounded=True, #When True, node boxes are drawn with rounded corners
                                feature_names=list(features), #If this is not specified, the feature names are not displayed on the chart (passed as a plain list)
                                class_names=list(iris.target_names), #If this is not specified, the class names are not displayed on the chart (passed as a plain list)
                                special_characters=True #Allow special characters in the output
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)
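If Graphviz is not available, or you only want a quick look at the splits, a tree can also be printed as plain text. This is a sketch of an alternative, not the method used in this entry: it assumes a scikit-learn version that provides sklearn.tree.export_text (newer than the one this entry was written with), and it uses a small forest with a fixed random_state purely for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

iris = load_iris()

# A small forest is enough for a quick look; these values are illustrative.
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(iris.data, iris.target)

# Render the first tree of the forest as indented plain text.
tree_text = export_text(clf.estimators_[0], feature_names=list(iris.feature_names))
print(tree_text)
```

The text output shows the same split conditions as the image, one indented line per node, but without the colors and class names of the Graphviz rendering.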

The decision tree is then obtained as a PNG image like the one below.

tree_visualization.png
