J'ai essayé de résumer comment générer plusieurs arbres de décision RF dans un fichier SVG en utilisant dtreeviz et svgutils.
J'ai utilisé le lien ci-dessous tel quel et l'ai exécuté pour le moment. [Essayez dtreeviz pour RandomForest](https://qiita.com/go50/item![Layout design.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/ 0/465379 / d58f187f-3c25-82d6-e904-47dfc64147af.png) s/38c7757b444db3867b17)
from sklearn.datasets import load_iris
from sklearn import tree
from dtreeviz.trees import dtreeviz
from sklearn.ensemble import RandomForestClassifier
iris = load_iris()
clf = RandomForestClassifier(n_estimators=100 , max_depth = 2)
clf.fit(iris.data, iris.target)
estimators = clf.estimators_
viz = dtreeviz(
estimators[0],
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
)
viz.view()
Le premier modèle des 100 arbres de décision peut être visualisé.
Le programme ci-dessus générera un fichier svg avec un modèle. Par traitement en boucle, tous les arbres de décision contenus dans RF étaient sortis en SVG. (Utilisez viz.save () car il est difficile d'afficher les 100)
Il est utilisé pour mesurer le temps de traitement.
from tqdm import tqdm
for estimator in tqdm(estimators):
viz = dtreeviz(
estimator,
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
)
viz.save()
Lorsque j'ai vérifié le dossier Temp de la destination de sortie, le problème était que seul un fichier SVG du modèle d'arbre de décision était enregistré.
Apparemment, la convention de dénomination du fichier de sortie inclut l'ID de processus de l'environnement d'exécution. Il semble que le même nom de fichier soit généré à chaque fois, le fichier SVG est mis à jour à chaque fois et seul le dernier modèle est enregistré. Contenu des packages de site \ dtreeviz \ tree.py
def save_svg(self):
"""Saves the current object as SVG file in the tmp directory and returns the filename"""
tmp = tempfile.gettempdir()
svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
self.save(svgfilename)
return svgfilename
Correction de la convention de dénomination des fichiers à générer en utilisant le temps d'exécution. Importez les données dans les packages de site \ dtreeviz \ tree.py
from datetime import datetime
Correction de save_svg ()
def save_svg(self):
"""Saves the current object as SVG file in the tmp directory and returns the filename"""
tmp = tempfile.gettempdir()
#svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
now = datetime.now()
svgfilename = os.path.join(tmp, f"DTreeViz_{now:%Y%m%d_%H%M%S}.svg")
self.save(svgfilename)
return svgfilename
⇒ Sortie SVG réussie de tous les modèles d'arbres déterminés
C'est très ennuyeux de regarder les fichiers ci-dessus un par un. Je l'ai intégré dans un fichier à l'aide de svgutils et l'ai sorti. (Je ne trouve pas le site auquel j'ai fait référence lors de l'utilisation de svgutils .. Je le lierai dès que je le redécouvrirai. )
Il est conçu pour être aussi carré que possible en fonction du nombre d'arbres de décision et afin que la mise en page puisse être corrigée immédiatement même si la profondeur de l'arbre de décision est modifiée.
Enregistrez les 100 SVG créés à l'avance dans un fichier spécifique et exécutez le programme suivant
import svgutils.transform as sg
import glob
import math
import os
def join_svg(cell_w, cell_h):
SVG_file_dir = "./SVG_files"
svg_filename_list = glob.glob(SVG_file_dir + "/*.svg")
fig_tmp = sg.SVGFigure("128cm", "108cm")
N = len(svg_filename_list)
n_w_cells = int(math.sqrt(N))
i = 0
plot_list, txt_list = [], []
for target_svg_file in svg_filename_list:
print("i : {}".format(i))
pla_x = i % n_w_cells
pla_y = int(i / n_w_cells)
print("Position du tracé:[x,y] : {},{}".format(pla_x, pla_y))
print(target_svg_file)
fig_target = sg.fromfile(target_svg_file)
plot_target = fig_target.getroot()
plot_target.moveto(cell_w * pla_x, cell_h * pla_y, scale=1)
print("Coordonnées du modèle: {},{}".format(cell_w * pla_x, cell_h * pla_y))
plot_list.append(plot_target)
txt_target = sg.TextElement(25 + cell_w * pla_x, 20 + cell_h * pla_y,
str(i), size=12, weight="bold")
print("Coordonnées du texte: {},{}".format(25 + cell_w * pla_x, 20 + cell_h * pla_y))
txt_list.append(txt_target)
print(i)
i += 1
fig_tmp.append(plot_list)
fig_tmp.append(txt_list)
ouput_dir = SVG_file_dir + "/output"
try :
fig_tmp.save(ouput_dir + "/RF.svg")
except FileNotFoundError:
os.mkdir(ouput_dir)
fig_tmp.save(ouput_dir + "/RF.svg")
join_svg(400, 300)
Tous les fichiers ont été combinés avec succès.
La taille du fichier est plus grande que prévu (environ 10 Mo). L'affichage prend du temps, même en chrome. Dans certains cas, vous pouvez obtenir une erreur en raison d'une mémoire insuffisante à afficher si d'autres applications sont toujours en cours d'exécution.
Recommended Posts