Understand the decision tree and try it with scikit-learn.
A decision tree classifies data by building a tree structure in which each node applies a threshold to a feature that is important for classification.
The decision tree is a highly interpretable classification model: by visualizing the tree structure, you can see which features are important for classification and at what threshold each split is made. Decision trees can also be used for regression.
There are several types of decision tree algorithms, but here we follow the CART algorithm used by scikit-learn.
In a tree structure, the node at the top of the tree is called the root node and the nodes at the ends of the tree are called leaf nodes, as shown in the figure below. Of two connected nodes, the upper one is called the parent node and the lower one the child node.
In CART, starting from the root node, each node is split in two by choosing the feature and threshold that maximize the information gain. Splitting is repeated until the leaf nodes are pure, that is, until all data contained in each leaf node belong to the same category.
However, splitting until the leaf nodes are pure results in overfitting, which is alleviated by pruning.
Decision tree learning chooses each split by maximizing the information gain $ IG $ in the equation below.
$$
IG(D_{parent}, f) = I(D_{parent}) - \frac{N_{left}}{N_{parent}} I(D_{left}) - \frac{N_{right}}{N_{parent}} I(D_{right})
$$
Here, $ f $ is the feature used for the split, $ D_{parent} $ is the data contained in the parent node, $ D_{left} $ and $ D_{right} $ are the data of the left and right child nodes, $ N_{parent} $ is the number of data in the parent node, $ N_{left} $ and $ N_{right} $ are the numbers of data in the left and right child nodes, and $ I $ is the impurity described below.
During training, the lower the impurity of the left and right child nodes produced by a split, the larger the information gain; each node is split on the feature and threshold that give the largest gain.
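To make the split criterion concrete, here is a minimal NumPy sketch that evaluates $ IG $ for every candidate threshold on a single feature, using the Gini impurity $ I_G $ defined below as $ I $. The helper names (gini, information_gain, best_threshold) and the toy data are hypothetical, and this brute-force search over midpoints is only an illustration, not scikit-learn's implementation.

```python
import numpy as np


def gini(labels):
    """Gini impurity I_G = 1 - sum_i P(C_i|t)^2 of the labels in one node."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def information_gain(x, y, threshold, impurity=gini):
    """IG(D_parent, f) for the split x <= threshold (left) / x > threshold (right)."""
    left, right = y[x <= threshold], y[x > threshold]
    n_parent = len(y)
    return (impurity(y)
            - len(left) / n_parent * impurity(left)
            - len(right) / n_parent * impurity(right))


def best_threshold(x, y, impurity=gini):
    """Try the midpoints between sorted unique feature values; return the best one."""
    values = np.unique(x)
    candidates = (values[:-1] + values[1:]) / 2.0
    gains = [information_gain(x, y, t, impurity) for t in candidates]
    best = int(np.argmax(gains))
    return candidates[best], gains[best]


# Toy feature f and labels: the classes separate cleanly between 3 and 4.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0, 0, 0, 1, 1, 1])
threshold, gain = best_threshold(x, y)
print(threshold, gain)  # 3.5 0.5
```

The split at 3.5 leaves both children pure, so the whole parent impurity of 0.5 is recovered as gain; any other threshold leaves one child mixed and scores lower.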
The following three measures are typically used to evaluate impurity. Here, $ C_i \ (i = 1, 2, \ldots, K) $ are the $ K $ categories, $ t $ is a node, and $ P(C_i | t) $ is the probability that a data point in node $ t $ belongs to category $ C_i $.
The classification error $ I_E $ is not very sensitive to changes in the class probabilities of a node, so it is used for tree pruning, as described below.
$$
I_E = 1 - \max_i P(C_i | t)
$$
Entropy $ I_H $ is 0 when all the data contained in the node belongs to the same category.
$$
I_H = -\sum^K_{i=1} P(C_i | t) \ln P(C_i | t)
$$
The Gini impurity $ I_G $ can be interpreted as the probability of misclassifying a data point in the node if it were labeled at random according to the node's class distribution. Like entropy, it is 0 when all data contained in the node belong to the same category.
$$
I_G = 1 - \sum^K_{i=1} P^2 (C_i | t)
$$
Each function is plotted below; Gini is the default criterion (criterion="gini") in scikit-learn.
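The three measures can also be compared numerically. The function names below are hypothetical; they simply transcribe $ I_E $, $ I_H $ and $ I_G $ for a vector of class probabilities $ P(C_i | t) $. In scikit-learn itself, the measure is selected with the criterion argument of DecisionTreeClassifier ("gini" or "entropy").

```python
import numpy as np


def classification_error(p):
    """I_E = 1 - max_i P(C_i|t)."""
    return 1.0 - np.max(p)


def entropy(p):
    """I_H = -sum_i P(C_i|t) ln P(C_i|t), written as sum p ln(1/p)."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]  # 0 * ln 0 is taken as 0
    return float(np.sum(p * np.log(1.0 / p)))


def gini(p):
    """I_G = 1 - sum_i P(C_i|t)^2."""
    p = np.asarray(p, dtype=float)
    return 1.0 - np.sum(p ** 2)


# Two-class node with P(C_1|t) = q and P(C_2|t) = 1 - q.
for q in [0.0, 0.1, 0.3, 0.5]:
    p = np.array([q, 1.0 - q])
    print(f"q={q:.1f}  I_E={classification_error(p):.3f}  "
          f"I_H={entropy(p):.3f}  I_G={gini(p):.3f}")
```

For a two-class node, all three measures are 0 when the node is pure (q = 0) and reach their maximum at q = 0.5.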
During training, the tree is grown until the leaf nodes are pure; left as it is, such a tree overfits, so pruning is performed to alleviate this.
As the evaluation criterion for tree pruning, we define the resubstitution error rate, which is the error rate obtained when the training data are fed back into the tree. The resubstitution error rate $ R(t) $ at a node $ t $ is expressed by the following equation using the classification error $ I_E $ and the marginal probability $ P(t) $ of the node $ t $.
$$
\begin{aligned}
R(t) &= \left( 1 - \max_i P(C_i | t) \right) \times P(t) \\
P(t) &= \frac{N(t)}{N}
\end{aligned}
$$
Here, $ N(t) $ is the number of data contained in node $ t $ and $ N $ is the total number of training data.
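Assuming a fitted scikit-learn tree, $ R(t) $ can be computed for every node from the public tree_ attributes (value, n_node_samples, children_left). This is a sketch: the dataset, max_leaf_nodes=5 and random_state=0 are arbitrary choices, and the per-node class weights in tree_.value are normalized first because, depending on the scikit-learn version, they may hold counts or fractions.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
# max_leaf_nodes=5 is an arbitrary choice so that some leaves stay impure.
clf = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(X, y)
tree = clf.tree_

N = tree.n_node_samples[0]                             # total number of training data N
values = tree.value[:, 0, :]                           # per-node class weights (counts or fractions)
p_class = values / values.sum(axis=1, keepdims=True)   # P(C_i | t)
p_node = tree.n_node_samples / N                       # P(t) = N(t) / N
R = (1.0 - p_class.max(axis=1)) * p_node               # R(t) = I_E(t) * P(t)

is_leaf = tree.children_left == -1                     # leaves have no children
for t in range(tree.node_count):
    kind = "leaf" if is_leaf[t] else "test"
    print(f"node={t} ({kind}) N(t)={tree.n_node_samples[t]} R(t)={R[t]:.4f}")

# Summed over the leaves, R(t) equals the training (resubstitution) error rate.
print("sum over leaves:", R[is_leaf].sum(), " 1 - accuracy:", 1.0 - clf.score(X, y))
```

Because each leaf predicts its majority class, the misclassified fraction of a leaf is exactly $ 1 - \max_i P(C_i | t) $, which is why the leaf sum matches the training error.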
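Tree pruning removes branches of the tree based on this resubstitution error rate. With scikit-learn, you can limit the tree to the required number of leaf nodes with the argument max_leaf_nodes. Note that max_leaf_nodes grows the tree best-first until that many leaves exist rather than pruning a fully grown tree; post-pruning of a grown tree is available through the ccp_alpha argument (minimal cost-complexity pruning, scikit-learn 0.22 and later).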
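A sketch of both options in scikit-learn, under assumed settings (an 80/20 split, random_state=0, max_leaf_nodes=5): max_leaf_nodes caps the leaf count directly, while cost_complexity_pruning_path returns the effective alphas at which a fully grown tree would be pruned, and refitting with each value as ccp_alpha gives the corresponding pruned tree.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Fully grown tree: splits until every leaf is pure.
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Option 1: cap the number of leaves (grown best-first, not post-pruned).
capped = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(X_train, y_train)

# Option 2: minimal cost-complexity pruning of the fully grown tree.
path = full.cost_complexity_pruning_path(X_train, y_train)
pruned = [DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_train, y_train)
          for a in path.ccp_alphas]

print("full  :", full.get_n_leaves(), "leaves, test acc", full.score(X_test, y_test))
print("capped:", capped.get_n_leaves(), "leaves, test acc", capped.score(X_test, y_test))
for a, m in zip(path.ccp_alphas, pruned):
    print(f"ccp_alpha={a:.5f}: {m.get_n_leaves()} leaves, "
          f"test acc {m.score(X_test, y_test):.3f}")
```

Larger ccp_alpha values prune more of the tree, so in practice the value is usually chosen by comparing held-out accuracy along this path.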
The execution environment is as follows.
- CPU: Intel(R) Core(TM) i7-6700K 4.00GHz
- OS: Windows 10 Pro 1909
- Python 3.6.6
- Matplotlib 3.3.1
- NumPy 1.19.2
- scikit-learn 0.23.2
The implemented programs are published on GitHub.
- decision_tree_clf.py
- decision_tree_reg.py
I applied a decision tree to the breast cancer dataset I've been using so far. Because a decision tree sets a threshold on each feature individually, there is no need to standardize the features as preprocessing.
Accuracy 97.37%
Precision, Positive predictive value(PPV) 97.06%
Recall, Sensitivity, True positive rate(TPR) 98.51%
Specificity, True negative rate(TNR) 95.74%
Negative predictive value(NPV) 97.83%
F-Score 97.78%
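The metrics above can be computed from a confusion matrix as in the following sketch. The split (test_size=0.2, random_state=0), max_leaf_nodes=5, and treating label 1 (benign) as the positive class are assumptions, so the printed values need not match the ones above. The last part checks the claim that standardization is unnecessary by refitting on StandardScaler-transformed features.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)   # 0 = malignant, 1 = benign (assumed positive)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Hyperparameters are illustrative assumptions, not the author's settings.
clf = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(X_train, y_train)
y_pred = clf.predict(X_test)

tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
precision = tp / (tp + fp)
recall = tp / (tp + fn)
metrics = {
    "Accuracy": (tp + tn) / (tp + tn + fp + fn),
    "Precision, Positive predictive value(PPV)": precision,
    "Recall, Sensitivity, True positive rate(TPR)": recall,
    "Specificity, True negative rate(TNR)": tn / (tn + fp),
    "Negative predictive value(NPV)": tn / (tn + fn),
    "F-Score": 2 * precision * recall / (precision + recall),
}
for name, value in metrics.items():
    print(f"{name} {value:.2%}")

# Standardization is an increasing affine map per feature, so the ordering of
# values (and therefore the available splits) is unchanged.
scaler = StandardScaler().fit(X_train)
clf_scaled = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(
    scaler.transform(X_train), y_train)
agreement = np.mean(clf_scaled.predict(scaler.transform(X_test)) == y_pred)
print(f"agreement between raw and standardized inputs: {agreement:.2%}")
```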
scikit-learn provides a plot_tree function that visualizes the decision tree, making it easy to see the tree structure of the trained model.
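A minimal sketch of plot_tree, plus export_text for a plain-text rendering of the same tree. The dataset and hyperparameters are illustrative assumptions, and the figure is rendered off-screen into an in-memory PNG so the snippet also runs without a display.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen; remove to show a window instead
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree

data = load_breast_cancer()
clf = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(data.data, data.target)

# Graphical tree: each box shows the split rule, impurity, sample count and class.
fig, ax = plt.subplots(figsize=(12, 6))
annotations = plot_tree(clf, feature_names=list(data.feature_names),
                        class_names=list(data.target_names), filled=True, ax=ax)
buf = io.BytesIO()
fig.savefig(buf, format="png")  # or plt.show() in an interactive session
plt.close(fig)

# Text version of the same tree (no figure needed).
text = export_text(clf, feature_names=list(data.feature_names))
print(text)
```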
You can also print the split rule used at each node of the trained model to the command line, as follows:
The binary tree structure has 9 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 27] <= 0.1423499956727028 else to node 2.
node=1 test node: go to node 3 if X[:, 23] <= 957.4500122070312 else to node 4.
node=2 test node: go to node 5 if X[:, 23] <= 729.5499877929688 else to node 6.
node=3 leaf node.
node=4 leaf node.
node=5 test node: go to node 7 if X[:, 4] <= 0.10830000042915344 else to node 8.
node=6 leaf node.
node=7 leaf node.
node=8 leaf node.
Rules used to predict sample 0:
decision id node 0 : (X_test[0, 27](= 0.2051) > 0.1423499956727028)
decision id node 2 : (X_test[0, 23](= 844.4) > 729.5499877929688)
The following samples [0, 1] share the node [0] in the tree
It is 11.11% of all nodes.
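Output like the above can be produced from the fitted tree's tree_ arrays (children_left, children_right, feature, threshold) together with decision_path and apply, in the style of scikit-learn's "Understanding the decision tree structure" example. This is a sketch under assumed settings (an 80/20 split, random_state=0, max_leaf_nodes=5), so the node ids and thresholds it prints may differ from the ones above.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(X_train, y_train)

tree = clf.tree_
left, right = tree.children_left, tree.children_right
feature, threshold = tree.feature, tree.threshold

# A node is a leaf when it has no children (child id == -1).
print(f"The binary tree structure has {tree.node_count} nodes:")
for node in range(tree.node_count):
    if left[node] == -1:
        print(f"node={node} leaf node.")
    else:
        print(f"node={node} test node: go to node {left[node]} if "
              f"X[:, {feature[node]}] <= {threshold[node]} else to node {right[node]}.")

# decision_path returns a sparse (samples x nodes) indicator matrix;
# apply returns the id of the leaf each sample ends in.
sample_id = 0
node_indicator = clf.decision_path(X_test)
leaf_id = clf.apply(X_test)
start, end = node_indicator.indptr[sample_id], node_indicator.indptr[sample_id + 1]
path = node_indicator.indices[start:end]
print(f"Rules used to predict sample {sample_id}:")
for node in path:
    if leaf_id[sample_id] == node:
        continue  # the leaf itself carries no test
    value = X_test[sample_id, feature[node]]
    sign = "<=" if value <= threshold[node] else ">"
    print(f"decision id node {node} : (X_test[{sample_id}, {feature[node]}](= {value}) "
          f"{sign} {threshold[node]})")

# Nodes visited by both sample 0 and sample 1.
sample_ids = [0, 1]
common = node_indicator.toarray()[sample_ids].sum(axis=0) == len(sample_ids)
common_nodes = np.arange(tree.node_count)[common]
print(f"The following samples {sample_ids} share the nodes {list(common_nodes)} in the tree")
print(f"It is {100 * len(common_nodes) / tree.node_count:.2f}% of all nodes.")
```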
The figure below shows the decision boundaries of a multiclass classification on the iris dataset.
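Decision boundaries like these are usually obtained by fitting on two features and predicting every point of a dense grid. In this sketch, using petal length and width and max_depth=3 are assumptions; the resulting class regions are axis-aligned rectangles because every test compares a single feature with a threshold.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:4]   # petal length, petal width (two features, so the boundary is 2-D)
y = iris.target         # three classes: setosa, versicolor, virginica

clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

# Evaluate the classifier on a dense grid; each grid point gets a class label.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 300), np.linspace(y_min, y_max, 300))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

# With matplotlib: plt.contourf(xx, yy, Z, alpha=0.3); plt.scatter(X[:, 0], X[:, 1], c=y)
print("classes on the grid:", np.unique(Z))
print("training accuracy:", clf.score(X, y))
```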
For the regression problem, the data are a sine wave plus random noise. The results show that deepening the tree increases its expressiveness.
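A sketch of this regression setting, assuming 200 points on [0, 5) with Gaussian noise of standard deviation 0.1 (the exact data used above are not given). A DecisionTreeRegressor predicts a constant value per leaf, so a deeper tree produces more, smaller steps and a training error that never increases with depth.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(200, 1), axis=0)
y = np.sin(X).ravel() + 0.1 * rng.randn(200)   # sine wave plus random noise

X_grid = np.linspace(0, 5, 500)[:, np.newaxis]
results = {}
for depth in [1, 2, 5]:
    reg = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X, y)
    mse = np.mean((reg.predict(X) - y) ** 2)
    n_steps = len(np.unique(reg.predict(X_grid)))  # distinct constant levels on the grid
    results[depth] = (reg.get_n_leaves(), mse)
    print(f"max_depth={depth}: {reg.get_n_leaves()} leaves, "
          f"{n_steps} constant levels, training MSE {mse:.4f}")
```

A very deep tree eventually fits the noise as well as the sine wave, so the depth is normally tuned on held-out data.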
The figure below is an example applied to a multi-output regression problem.
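For multi-output regression, DecisionTreeRegressor accepts a 2-D target array directly, and each leaf stores one mean value per output. In the sketch below, the circle-shaped targets and the noise pattern are assumptions modeled loosely on scikit-learn's multi-output tree example.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
# Two targets per sample: a point on a circle, (pi*sin(x), pi*cos(x)).
Y = np.column_stack([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()])
Y[::5, :] += 0.5 - rng.rand(20, 2)  # perturb every fifth sample

# A single tree predicts both outputs at once.
reg = DecisionTreeRegressor(max_depth=5, random_state=0).fit(X, Y)
X_test = np.linspace(-100, 100, 2000)[:, np.newaxis]
Y_pred = reg.predict(X_test)
print(Y_pred.shape)    # (2000, 2)
print(reg.n_outputs_)  # 2
```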