This post summarizes how to output multiple decision trees of a random forest (RF) into one SVG file using dtreeviz and svgutils.
First, I ran the code from the following article as-is: [Try dtreeviz for RandomForest](https://qiita.com/go50/items/38c7757b444db3867b17)

![Layout design.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/465379/d58f187f-3c25-82d6-e904-47dfc64147af.png)
```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from dtreeviz.trees import dtreeviz

iris = load_iris()
clf = RandomForestClassifier(n_estimators=100, max_depth=2)
clf.fit(iris.data, iris.target)
estimators = clf.estimators_

# Visualize only the first tree of the forest
viz = dtreeviz(
    estimators[0],
    iris.data,
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=[str(i) for i in iris.target_names],
)
viz.view()
```
This visualizes the first of the 100 decision trees.
With the code above, one model produces one SVG file. To output every decision tree in the RF as SVG, I looped over all estimators. Since displaying all 100 trees is tedious, I saved them with `viz.save_svg()` instead of displaying them with `viz.view()`.
`tqdm` shows a progress bar and the elapsed time of the loop.

```python
from tqdm import tqdm

for estimator in tqdm(estimators):
    viz = dtreeviz(
        estimator,
        iris.data,
        iris.target,
        target_name='variety',
        feature_names=iris.feature_names,
        class_names=[str(i) for i in iris.target_names],
    )
    # Writes the SVG to the temp directory and returns its path
    viz.save_svg()
```
When I checked the output folder (the temp directory), there was a problem: only one decision tree's SVG file had been saved.
The cause is the output file naming rule, which includes the process ID of the running Python process. Because the process ID stays the same for the whole run, every call generates the same file name, each new SVG overwrites the previous one, and only the last model remains. The relevant code in site-packages\dtreeviz\trees.py is:
```python
def save_svg(self):
    """Saves the current object as SVG file in the tmp directory and returns the filename"""
    tmp = tempfile.gettempdir()
    svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
    self.save(svgfilename)
    return svgfilename
```
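To see why the loop overwrites its own output, the naming rule can be reproduced with the standard library alone. The sketch below mimics the f-string in `save_svg()`; `pid_based_name` is a hypothetical helper for illustration, not part of dtreeviz.

```python
import os
import tempfile


def pid_based_name() -> str:
    """Hypothetical helper mirroring the naming rule in dtreeviz's save_svg()."""
    tmp = tempfile.gettempdir()
    return os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")


# One Python process has one PID, so every "save" in the loop gets the same path
names = [pid_based_name() for _ in range(100)]
print(len(set(names)))  # 1: all 100 trees would be written to the same file
```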
I changed the naming rule so that the file name is generated from the current time. First, import datetime in site-packages\dtreeviz\trees.py:

```python
from datetime import datetime
```

Then modify save_svg():

```python
def save_svg(self):
    """Saves the current object as SVG file in the tmp directory and returns the filename"""
    tmp = tempfile.gettempdir()
    # Original rule: one name per process, so repeated calls overwrite each other
    # svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
    now = datetime.now()
    svgfilename = os.path.join(tmp, f"DTreeViz_{now:%Y%m%d_%H%M%S}.svg")
    self.save(svgfilename)
    return svgfilename
```
⇒ All decision tree models were output as SVG successfully.
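One caveat, not something I measured for this data set: `%Y%m%d_%H%M%S` has one-second resolution, so two trees rendered within the same second would still get the same name. The sketch below shows that collision with two fixed timestamps and, as an alternative assumption, builds one name per tree index with a hypothetical helper `tree_svg_path`. Passing such a path to `viz.save(...)` inside the loop should avoid editing site-packages at all.

```python
import os
from datetime import datetime

# Two renders 0.8 s apart but inside the same wall-clock second
t1 = datetime(2020, 1, 1, 12, 0, 0, 100_000)
t2 = datetime(2020, 1, 1, 12, 0, 0, 900_000)
name1 = f"DTreeViz_{t1:%Y%m%d_%H%M%S}.svg"
name2 = f"DTreeViz_{t2:%Y%m%d_%H%M%S}.svg"
print(name1 == name2)  # True: same second, same file name


def tree_svg_path(out_dir: str, index: int) -> str:
    """Hypothetical helper: one file per tree, zero-padded so names sort in tree order."""
    return os.path.join(out_dir, f"tree_{index:03d}.svg")


paths = [tree_svg_path("SVG_files", i) for i in range(100)]
print(len(set(paths)))  # 100 distinct names
```

In the loop this would become `for i, estimator in enumerate(estimators):` followed by `viz.save(tree_svg_path("./SVG_files", i))` (after creating the folder), which also writes the files straight into the folder used by the joining step.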
Checking these files one by one is tedious, so I combined them into one file with svgutils. (I can no longer find the site I referred to for svgutils; I will add a link as soon as I find it again.)
The trees are laid out in a grid that is kept as close to square as possible for the given number of trees, and the cell width and height are parameters, so the layout can be adjusted immediately if the tree depth changes.
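The placement rule is plain arithmetic, so it can be checked without svgutils. In the sketch below, `grid_layout` is a hypothetical helper that reproduces the column/row computation of the `join_svg` program (floor(sqrt(N)) columns, filled row by row) and also reports the size of the tiled area in SVG user units, i.e. the area the canvas has to cover.

```python
import math


def grid_layout(n_trees: int, cell_w: int, cell_h: int):
    """Return (top-left (x, y) per tree, tiled width, tiled height) in SVG user units."""
    # floor(sqrt(N)) columns; guard against N = 0 to avoid a modulo by zero
    n_cols = max(1, int(math.sqrt(n_trees)))
    n_rows = math.ceil(n_trees / n_cols)
    positions = [((i % n_cols) * cell_w, (i // n_cols) * cell_h) for i in range(n_trees)]
    return positions, n_cols * cell_w, n_rows * cell_h


positions, width, height = grid_layout(100, 400, 300)
print(positions[0], positions[11], positions[99])  # (0, 0) (400, 300) (3600, 2700)
print(width, height)  # 4000 3000
```

With 100 trees this gives a 10 × 10 grid; for a count that is not a perfect square, such as 10 trees, it gives 3 columns and 4 rows.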
Move the 100 SVG files generated in advance into one folder (./SVG_files below) and run the following program:
```python
import glob
import math
import os

import svgutils.transform as sg


def join_svg(cell_w, cell_h):
    """Tile every SVG in SVG_file_dir onto one canvas, one cell_w x cell_h cell per tree."""
    SVG_file_dir = "./SVG_files"
    # glob returns files in arbitrary order; sort so placement and labels follow file names
    svg_filename_list = sorted(glob.glob(SVG_file_dir + "/*.svg"))
    fig_tmp = sg.SVGFigure("128cm", "108cm")

    N = len(svg_filename_list)
    # Number of columns: floor(sqrt(N)), e.g. 10 columns for 100 trees
    n_w_cells = int(math.sqrt(N))

    plot_list, txt_list = [], []
    for i, target_svg_file in enumerate(svg_filename_list):
        # Column and row of the i-th tree
        pla_x = i % n_w_cells
        pla_y = i // n_w_cells
        print("i : {}".format(i))
        print("Plot position:[x,y] : {},{}".format(pla_x, pla_y))
        print(target_svg_file)

        # Load one tree and move it to the top-left corner of its cell
        fig_target = sg.fromfile(target_svg_file)
        plot_target = fig_target.getroot()
        plot_target.moveto(cell_w * pla_x, cell_h * pla_y, scale=1)
        print("Model coordinates: {},{}".format(cell_w * pla_x, cell_h * pla_y))
        plot_list.append(plot_target)

        # Label each cell with its index, slightly offset from the corner
        txt_target = sg.TextElement(25 + cell_w * pla_x, 20 + cell_h * pla_y,
                                    str(i), size=12, weight="bold")
        print("Text coordinates: {},{}".format(25 + cell_w * pla_x, 20 + cell_h * pla_y))
        txt_list.append(txt_target)

    fig_tmp.append(plot_list)
    fig_tmp.append(txt_list)

    # Create the output folder if needed, then write the combined figure
    output_dir = SVG_file_dir + "/output"
    os.makedirs(output_dir, exist_ok=True)
    fig_tmp.save(output_dir + "/RF.svg")


join_svg(400, 300)
```
All files were combined successfully.
The output file is larger than expected (about 10 MB) and takes a while to display even in Chrome. If other applications are still running, displaying it can sometimes fail with an out-of-memory error.