Grid Search
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
First, to use grid search, you need to choose which parameters to optimize and define the set of values to try for each one.
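from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.model_selection import GridSearchCV
# Base model; random_state fixes the randomness so results are reproducible
learner = RandomForestClassifier(random_state=2)
# Candidate values for each hyperparameter (5 x 5 = 25 combinations)
n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}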
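As a rough illustration of how large this grid is, the sketch below enumerates the combinations with itertools.product in plain Python (no scikit-learn needed). The fold count k = 5 is an assumption based on GridSearchCV's default cv in recent scikit-learn versions; older versions defaulted to 3 folds.

```python
from itertools import product

# Same grid as above
n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

# Every combination of values is one candidate model
candidates = [dict(zip(parameters, values)) for values in product(*parameters.values())]
n_candidates = len(candidates)  # 5 * 5 = 25

# With k-fold cross-validation, each candidate is trained k times
k = 5  # assumed number of folds
n_fits = n_candidates * k
print(n_candidates, n_fits)  # 25 125
```

Because every extra value multiplies the number of fits, it usually pays to keep each list short and coarse at first.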
Here, AUC is used as the scoring metric. One way to do this is to wrap roc_auc_score in your own scorer function.
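def auc_scorer(y_true, y_pred):
    # roc_auc_score takes the true labels first and the predictions second
    # (with make_scorer's defaults, y_pred holds the output of predict(), i.e. hard labels)
    return roc_auc_score(y_true, y_pred)
scorer = make_scorer(auc_scorer, greater_is_better=True)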
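Why the argument order and the kind of prediction matter can be sketched without scikit-learn. The helper pairwise_auc below is a hypothetical stand-in for roc_auc_score: it computes AUC as the fraction of (positive, negative) pairs in which the positive sample gets the higher score, counting ties as one half. The data is made up for illustration. With make_scorer's default settings the scorer receives the output of predict(), i.e. hard 0/1 labels, and AUC computed from hard labels is usually coarser than AUC computed from probabilities.

```python
def pairwise_auc(y_true, y_score):
    """AUC as the probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.3, 0.45, 0.8]                    # continuous scores, e.g. predicted probabilities
y_label = [1 if s >= 0.5 else 0 for s in y_score]  # hard labels, like predict() would return

print(pairwise_auc(y_true, y_score))  # 1.0
print(pairwise_auc(y_true, y_label))  # 0.75
```

If you want AUC on continuous scores, scikit-learn's built-in scoring='roc_auc' scorer uses decision_function or predict_proba instead of predict.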
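Next, you can define the grid search object and fit it.
grid_obj = GridSearchCV(learner, parameters, scoring=scorer)
# X_train and y_train stand for your own training data (not shown in this post)
grid_obj = grid_obj.fit(X_train, y_train)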
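What GridSearchCV does with these inputs can be sketched as an exhaustive loop. This is a simplified stand-in, not scikit-learn's implementation: grid_search and toy_score are hypothetical names, and toy_score replaces the real "fit on the training folds, score on the validation fold, average over folds" step with a made-up formula so the sketch runs on its own.

```python
from itertools import product

def grid_search(param_grid, evaluate):
    """Try every combination in param_grid and return (best_params, best_score, all_results).

    `evaluate` maps a params dict to a score where higher is better; it stands in
    for cross-validated training and scoring.
    """
    keys = sorted(param_grid)  # fixed key order for reproducible iteration
    results = []
    for values in product(*(param_grid[k] for k in keys)):
        params = dict(zip(keys, values))
        results.append((params, evaluate(params)))
    best_params, best_score = max(results, key=lambda r: r[1])
    return best_params, best_score, results

# Hypothetical toy objective: prefers more trees and smaller leaves
def toy_score(params):
    return params['n_estimators'] / 60 - params['min_samples_leaf'] / 16

grid = {'n_estimators': [12, 24, 36, 48, 60], 'min_samples_leaf': [1, 2, 4, 8, 16]}
best_params, best_score, results = grid_search(grid, toy_score)
print(best_params)  # {'min_samples_leaf': 1, 'n_estimators': 60}
```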
Heat Map
http://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html
To create a heat map, you first need a 2-dimensional matrix. After fitting, the grid search object stores the mean cross-validated score for every parameter combination in cv_results_['mean_test_score']. In the example below, these scores are reshaped into the matrix scores.
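# Candidates are ordered with min_samples_leaf varying slowest,
# so rows = min_samples_leaf and columns = n_estimators
scores = grid_obj.cv_results_['mean_test_score'].reshape(len(min_samples_leaf), len(n_estimators))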
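The reshape relies on the order in which candidates appear in cv_results_. As far as I know, scikit-learn's ParameterGrid iterates over the parameter names in sorted order with the last name varying fastest, so here min_samples_leaf changes slowest and n_estimators fastest. The sketch below mimics that order in plain Python to show which cell ends up where; on a fitted object you can confirm it by comparing with cv_results_['param_min_samples_leaf'] and cv_results_['param_n_estimators'].

```python
from itertools import product

n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]
parameters = {'n_estimators': n_estimators, 'min_samples_leaf': min_samples_leaf}

# Mimic the assumed candidate order: names sorted, last name varies fastest
keys = sorted(parameters)  # ['min_samples_leaf', 'n_estimators']
order = [dict(zip(keys, v)) for v in product(*(parameters[k] for k in keys))]

# Reshape the flat list into rows = min_samples_leaf, columns = n_estimators
n_rows, n_cols = len(min_samples_leaf), len(n_estimators)
grid = [order[i * n_cols:(i + 1) * n_cols] for i in range(n_rows)]

print(grid[0][4])  # {'min_samples_leaf': 1, 'n_estimators': 60}
print(grid[4][0])  # {'min_samples_leaf': 16, 'n_estimators': 12}
```

Both lists have five values here, so swapping the reshape arguments would not raise an error; it would only mislabel the axes, which is why the order is worth checking.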
Note: scores contains the following array.
[[ 0.91803961 0.92444425 0.9264368 0.92730609 0.92808348]
[ 0.91263539 0.91757799 0.91892211 0.91957058 0.91950196]
[ 0.90143663 0.90590379 0.90669241 0.90751479 0.90758263]
[ 0.89168321 0.89370183 0.89414698 0.89497685 0.89506426]
[ 0.88276445 0.88386261 0.88380793 0.88408826 0.88448689]]
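To read the best combination off this matrix, you can search for its largest entry. The sketch uses the values printed above and assumes rows correspond to min_samples_leaf and columns to n_estimators, matching the axis labels of the plot. On a fitted object, grid_obj.best_params_ and grid_obj.best_score_ give the same information directly.

```python
# Mean AUC values printed above: rows = min_samples_leaf, columns = n_estimators
scores = [
    [0.91803961, 0.92444425, 0.9264368, 0.92730609, 0.92808348],
    [0.91263539, 0.91757799, 0.91892211, 0.91957058, 0.91950196],
    [0.90143663, 0.90590379, 0.90669241, 0.90751479, 0.90758263],
    [0.89168321, 0.89370183, 0.89414698, 0.89497685, 0.89506426],
    [0.88276445, 0.88386261, 0.88380793, 0.88408826, 0.88448689],
]
n_estimators = [12, 24, 36, 48, 60]
min_samples_leaf = [1, 2, 4, 8, 16]

# Largest value together with its row and column index
best_value, best_row, best_col = max(
    (value, i, j) for i, row in enumerate(scores) for j, value in enumerate(row)
)
print(min_samples_leaf[best_row], n_estimators[best_col], best_value)  # 1 60 0.92808348
```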
Then, you can use scores to plot a heat map.
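import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
# Rows of scores are min_samples_leaf values, columns are n_estimators values
plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('n_estimators')
plt.ylabel('min_samples_leaf')
plt.colorbar()
# Label each column/row with the actual parameter value instead of its index
plt.xticks(np.arange(len(n_estimators)), n_estimators)
plt.yticks(np.arange(len(min_samples_leaf)), min_samples_leaf)
plt.title('Grid Search AUC Score')
plt.show()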
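When reading the colors, keep in mind that imshow, when no vmin/vmax is given, stretches the colormap linearly between the smallest and largest value in the data. The plain-Python sketch below reproduces that normalization for the scores printed above; it shows that the whole color range covers only about 0.045 AUC, so strong color contrasts can correspond to small score differences.

```python
# Mean AUC values printed above: rows = min_samples_leaf, columns = n_estimators
scores = [
    [0.91803961, 0.92444425, 0.9264368, 0.92730609, 0.92808348],
    [0.91263539, 0.91757799, 0.91892211, 0.91957058, 0.91950196],
    [0.90143663, 0.90590379, 0.90669241, 0.90751479, 0.90758263],
    [0.89168321, 0.89370183, 0.89414698, 0.89497685, 0.89506426],
    [0.88276445, 0.88386261, 0.88380793, 0.88408826, 0.88448689],
]

# Linear min-max normalization: data minimum -> 0.0 (darkest), maximum -> 1.0 (brightest)
lo = min(min(row) for row in scores)
hi = max(max(row) for row in scores)
normalized = [[(v - lo) / (hi - lo) for v in row] for row in scores]

print(round(hi - lo, 8))  # 0.04531903
```

Passing fixed vmin and vmax to plt.imshow makes heat maps from different experiments comparable on the same color scale.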
Running this code produces a heat map of the mean AUC score for each parameter combination.