Draw a CNN diagram in Python

ConvNet Drawer (Added on January 4, 2018)

The following tool has since been released, and I recommend using it: "I made a tool that illustrates the architecture of a convolutional neural network defined like a Keras model"

It is a tool that nicely illustrates the architecture when you define a model in a notation like Keras's Sequential model. Because the tool only outputs text, it has no library dependencies. https://github.com/yu4u/convnet-drawer

Overview

This post uses Python + pydotplus + Graphviz to draw a diagram of a CNN architecture. My motivation: after seeing https://github.com/jettan/tikz_cnn, I wanted to draw a similar diagram in Python instead of TeX.

Preparation

Install pydotplus and Graphviz. I'm using conda, but pip should also work (unverified).

conda install -c conda-forge pydotplus
conda install graphviz
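pydotplus only generates dot text; rendering to PNG is done by calling the Graphviz executables, so they need to be on your PATH (the CNN example later also uses the neato layout). A small sanity check, assuming the standard executable names `dot` and `neato`; the helper `find_graphviz` is my own, not part of pydotplus:

```python
import shutil


def find_graphviz(programs=("dot", "neato")):
    """Map each Graphviz program name to its full path, or None if it is not on PATH."""
    return {name: shutil.which(name) for name in programs}


if __name__ == "__main__":
    for name, path in find_graphviz().items():
        print(f"{name}: {path or 'NOT FOUND (is Graphviz installed and on PATH?)'}")
```

If either entry prints as not found, rendering with `write_png` / `create_png` will fail even though `import pydotplus` succeeds.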

Prepare a suitable dot file, load it with pydotplus, save it as an image, and display it. (The image is displayed in Jupyter; adapt as needed for other environments.)

drawCNN.py


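import pydotplus
from IPython.display import Image

# Load the graph description (dot/pytorchainer.dot must exist)
graph = pydotplus.graphviz.graph_from_dot_file('dot/pytorchainer.dot')
# Render with Graphviz and save as PNG (the img/ directory must already exist)
graph.write_png('img/pytorchainer.png')
# Display the rendered image inline in Jupyter
Image(graph.create_png())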

pytorchainer.dot


digraph G {

Python [shape=box]
Torch

Chainer -> "Chainer v2"
Chainer -> ChainerMN

Python -> PyTorch
Torch -> PyTorch
Chainer -> PyTorch

PyTorch -> PyTorChainer
"Chainer v2" -> PyTorChainer
ChainerMN -> PyTorChainer

"This figure is fiction." [shape=plaintext]

}

pytorchainer.png

Now you are ready to draw dot files from Python.

For the specifications of the dot language and PyDotPlus, please refer to the following:

Summary of how to draw a graph in Graphviz and dot language
PyDotPlus API Reference
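If you would rather generate the dot file itself from Python, a minimal sketch is below. The helpers `quote_id` and `build_digraph` are my own (hypothetical, not pydotplus functions); quoting every ID avoids the kind of syntax error an unquoted node name with spaces and a period would cause.

```python
def quote_id(text):
    """Return text as a double-quoted dot ID, escaping embedded double quotes."""
    return '"' + text.replace('"', '\\"') + '"'


def build_digraph(edges, node_attrs=None, name="G"):
    """Return dot source for a digraph.

    edges: iterable of (tail, head) node-name pairs.
    node_attrs: optional dict mapping a node name to a dict of attributes.
    """
    lines = [f"digraph {name} {{"]
    for node, attrs in (node_attrs or {}).items():
        attr_str = ", ".join(f"{key}={quote_id(str(value))}" for key, value in attrs.items())
        lines.append(f"    {quote_id(node)} [{attr_str}]")
    for tail, head in edges:
        lines.append(f"    {quote_id(tail)} -> {quote_id(head)}")
    lines.append("}")
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Same graph as pytorchainer.dot
    edges = [
        ("Chainer", "Chainer v2"), ("Chainer", "ChainerMN"),
        ("Python", "PyTorch"), ("Torch", "PyTorch"), ("Chainer", "PyTorch"),
        ("PyTorch", "PyTorChainer"), ("Chainer v2", "PyTorChainer"),
        ("ChainerMN", "PyTorChainer"),
    ]
    node_attrs = {
        "Python": {"shape": "box"},
        "This figure is fiction.": {"shape": "plaintext"},
    }
    dot_src = build_digraph(edges, node_attrs)
    print(dot_src)
    # To save it where graph_from_dot_file expects it:
    # from pathlib import Path
    # Path("dot").mkdir(exist_ok=True)
    # Path("dot/pytorchainer.dot").write_text(dot_src)
```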

CNN drawing

Now let's draw a diagram of a CNN architecture. All you really have to do is write the layers (and arrows) as dot-language nodes and edges. The code that follows is full of magic numbers for position adjustment; please forgive me.
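The position adjustment relies on neato's pinned positions: with `layout=neato`, a node attribute `pos="x,y!"` places the node's center at (x, y), in inches by default, and the trailing `!` tells neato not to move it. A minimal sketch of just that idea; the helpers `pinned_node` and `neato_graph` are my own names, and the resulting string can be passed to `pydotplus.graphviz.graph_from_dot_data`:

```python
def pinned_node(node_id, x, y, **attrs):
    """Return a dot node statement whose center is pinned at (x, y) inches."""
    parts = [f'{key}="{value}"' for key, value in attrs.items()]
    parts.append(f'pos="{x},{y}!"')  # '!' = keep this position fixed in neato
    return f'{node_id} [{", ".join(parts)}]'


def neato_graph(statements, size="16,8"):
    """Wrap node/edge statements in a digraph that uses the neato layout engine."""
    body = "\n".join(f"  {s}" for s in statements)
    return f'digraph G {{\n  graph [layout=neato, size="{size}"]\n{body}\n}}\n'


if __name__ == "__main__":
    src = neato_graph([
        pinned_node("a", 0.0, 0.0, shape="box"),
        pinned_node("b", 2.0, 1.0, shape="box"),
        "a -> b",
    ])
    print(src)
```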

drawCNN.py


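import pydotplus
from IPython.display import Image


class CNNDot:
    """Builds dot-language strings for layers and arrows at fixed (neato) positions."""

    def __init__(self):
        self.layer_id = 0  # counter used to give each layer node a unique name
        self.arrow_id = 0  # counter used to give each arrow's nodes unique names

    def get_layer_str(self, size, channels, xoffset=0.0, yoffset=0.0, fillcolor='white', caption=''):
        # note: `channels` is currently unused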
        width = size * 0.5
        height = size
        x = xoffset
        y = height * 0.5 + yoffset
        x_caption = x - width * 0.25
        y_caption = -y - 0.7
        
        layer_str = """
          layer{} [
              shape=polygon, sides=4, skew=-2, orientation=90,
              label="", style=filled, fixedsize=true, fillcolor="{}",
              width={}, height={}, pos="{},{}!"
          ]
        """.format(self.layer_id, fillcolor, width, height, x, y)

        if caption != '':
            layer_str += """
              layer_caption{} [
                  shape=plaintext, label="{}", fixedsize=true, fontsize=24,
                  pos="{},{}!"
              ]
            """.format(self.layer_id, caption, x_caption, y_caption)

        self.layer_id += 1
        return layer_str

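    def get_arrow_str(self, xmin, ymin, xmax, ymax):
        # An arrow is an edge between two invisible zero-size nodes pinned at (xmin, ymin) and (xmax, ymax)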
        arrow_str = """
            arrow{0}_tail [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{1},{2}!"
            ]
            arrow{0}_head [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{3},{4}!"
            ]
            arrow{0}_tail -> arrow{0}_head
        """.format(self.arrow_id, xmin, ymin, xmax, ymax)
        self.arrow_id += 1
        return arrow_str
        
cnndot = CNNDot()
# layers: get_layer_str(size, channels, xoffset, ...), where xoffset is the horizontal position
graph_data_main = cnndot.get_layer_str(3.0, 0, -1.00, fillcolor='gray') # input
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.00, caption='conv') # encoder begin
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.50)
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.25, caption='conv')
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 2.50, caption='conv')
graph_data_main += cnndot.get_layer_str(2.0, 0, 3.00)
graph_data_main += cnndot.get_layer_str(1.5, 0, 3.75, caption='conv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 4.25)
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.00, caption='conv')
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.50)
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.25, caption='deconv') # decoder begin
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.75)
graph_data_main += cnndot.get_layer_str(1.5, 0, 7.50, caption='deconv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 8.00)
graph_data_main += cnndot.get_layer_str(2.0, 0, 8.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 9.25)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.00)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.50)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.25)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.75)
graph_data_main += cnndot.get_layer_str(3.0, 0, 12.75, fillcolor='#FF8080') # output

# arrows: skip connections from each encoder stage to the decoder stage of the same size
graph_data_main += cnndot.get_arrow_str(0.50, 3.0*1.2, 11.25-0.22, 3.0*1.2)
graph_data_main += cnndot.get_arrow_str(1.75, 2.5*1.2, 10.00-0.20, 2.5*1.2)
graph_data_main += cnndot.get_arrow_str(3.00, 2.0*1.2, 8.75-0.18, 2.0*1.2)
graph_data_main += cnndot.get_arrow_str(4.25, 1.5*1.2, 7.50-0.16, 1.5*1.2)
graph_data_main += cnndot.get_arrow_str(5.50, 1.0*1.2, 6.25-0.14, 1.0*1.2)

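# neato keeps the nodes at their fixed pos="x,y!" coordinates; size caps the drawing at 16x8 inches
graph_data_setting = 'graph[ layout = neato, size="16,8"]'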
graph_data = 'digraph G {{ \n{}\n{}\n }}'.format(graph_data_setting, graph_data_main)
graph = pydotplus.graphviz.graph_from_dot_data(graph_data)

# save and show image
graph.write_png('img/encoder-decoder.png')
Image(graph.create_png())

Running this code should produce a figure like the one below. (Each layer is drawn as a thin slab by design. By extending the sides, you should be able to draw a rectangular box (cuboid) instead.)

encoder-decoder.png
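The hand-written layer and arrow calls follow a regular pattern: each stage is a pair of layers 0.5 apart, stages are 1.25 apart, and the decoder mirrors the encoder. A minimal sketch that computes the same parameters in a loop; the function `encoder_decoder_layout` and its parameter names are my own, and the `0.22 - 0.02 * i` arrow-head offset only reproduces the hand-tuned values. One deliberate difference: it captions every decoder stage 'deconv', whereas the hand-written calls caption only the first two.

```python
def encoder_decoder_layout(sizes=(3.0, 2.5, 2.0, 1.5, 1.0), pair_step=1.25,
                           pair_gap=0.5, io_gap=1.0):
    """Return (layers, arrows) for a symmetric encoder-decoder diagram.

    layers: list of (size, x, caption, fillcolor) tuples, in drawing order.
    arrows: list of (xmin, ymin, xmax, ymax) tuples for the skip connections.
    """
    n = len(sizes)
    layers = [(sizes[0], -io_gap, "", "gray")]  # input
    # Encoder: one stage (a pair of layers) per size, stages pair_step apart.
    for i, size in enumerate(sizes):
        x = pair_step * i
        layers += [(size, x, "conv", "white"), (size, x + pair_gap, "", "white")]
    # Decoder: the encoder mirrored, starting one step after the last encoder stage.
    dec_start = pair_step * n
    for j, size in enumerate(reversed(sizes)):
        x = dec_start + pair_step * j
        layers += [(size, x, "deconv", "white"), (size, x + pair_gap, "", "white")]
    last_x = dec_start + pair_step * (n - 1) + pair_gap
    layers.append((sizes[0], last_x + io_gap, "", "#FF8080"))  # output

    # Skip connections, drawn above the layers (y = 1.2 * size).
    arrows = []
    for i, size in enumerate(sizes):
        xmin = pair_step * i + pair_gap              # second layer of encoder stage i
        xmax = dec_start + pair_step * (n - 1 - i)   # first layer of the matching decoder stage
        head_offset = 0.22 - 0.02 * i                # hand-tuned gap before the arrow head
        y = size * 1.2
        arrows.append((xmin, y, xmax - head_offset, y))
    return layers, arrows


if __name__ == "__main__":
    layers, arrows = encoder_decoder_layout()
    # With the CNNDot class from drawCNN.py, these could replace the hand-written calls:
    #   graph_data_main = ""
    #   for size, x, caption, color in layers:
    #       graph_data_main += cnndot.get_layer_str(size, 0, x, fillcolor=color, caption=caption)
    #   for arrow in arrows:
    #       graph_data_main += cnndot.get_arrow_str(*arrow)
    for layer in layers:
        print(layer)
```

This keeps the magic numbers in one place, so changing the number of stages or the spacing only means changing the arguments.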

Impressions

Inception V3 drawing (2017/4/30 postscript)

Using code different from the above, I tried drawing a Keras model (InceptionV3). The cuboids are drawn with svgwrite and pasted in.

inceptionv3.png
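The InceptionV3 drawing code itself isn't shown in this post. As a rough illustration of the cuboid idea, here is a sketch that draws a cuboid as three SVG polygons (front, top, side) using an oblique projection. It writes plain SVG strings instead of using svgwrite, so it has no dependencies; the function names, projection angle, depth scale, and colors are my own assumptions, not the code used for the figure.

```python
import math


def cuboid_faces(x, y, w, h, d, angle_deg=45.0, depth_scale=0.5):
    """Return the front, top, and side faces of a cuboid as lists of (x, y) points.

    (x, y) is the top-left corner of the front face in SVG coordinates (y grows
    downward). The depth d recedes up and to the right (oblique projection).
    """
    a = math.radians(angle_deg)
    dx = d * depth_scale * math.cos(a)
    dy = -d * depth_scale * math.sin(a)  # negative: upward on screen
    front = [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]
    top = [(x, y), (x + dx, y + dy), (x + w + dx, y + dy), (x + w, y)]
    side = [(x + w, y), (x + w + dx, y + dy), (x + w + dx, y + h + dy), (x + w, y + h)]
    return front, top, side


def cuboid_svg(x, y, w, h, d, fill="#ccccff"):
    """Return three SVG <polygon> elements (one per visible face) for a cuboid."""
    polygons = []
    for face in cuboid_faces(x, y, w, h, d):
        points = " ".join(f"{px:.2f},{py:.2f}" for px, py in face)
        polygons.append(f'<polygon points="{points}" fill="{fill}" stroke="black"/>')
    return "\n".join(polygons)


def svg_document(elements, width, height):
    """Wrap SVG elements in a standalone <svg> document."""
    body = "\n".join(elements)
    return (f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">\n'
            f"{body}\n</svg>\n")


if __name__ == "__main__":
    # A tall thin layer followed by a smaller, thicker one.
    doc = svg_document([cuboid_svg(20, 60, 10, 120, 120),
                        cuboid_svg(120, 90, 30, 60, 60)], 300, 220)
    print(doc)  # save to a .svg file to view it in a browser
```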
