Notes on tf.function and Tracing

Introduction

This article is the 20th-day entry in the TensorFlow 2.0 Advent Calendar 2019. The major change in TensorFlow 2.0 is that Eager Execution is now the default, so you can write imperatively and more freely in a Pythonic style. On the other hand, this sacrifices performance and portability. To solve that problem, so that you can benefit from both the Graph mode of 1.x and the Eager mode of 2.x, tf.function was introduced. In this article, I'll introduce how to use tf.function and what you should know when using it. It is basically a summary of the official site, so if you want to know more, please refer to that as well.

How to use tf.function

It's easy to use: either add the @tf.function decorator to the function containing the heavy processing you want to optimize, or define the function normally and pass it to tf.function to create a separate Graph-mode version of it.

function_test.py


import tensorflow as tf

@tf.function
def f(x, y):
  return x + y

# Or

def g(x, y):
  return x * y

h = tf.function(g)
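
Either way, the result can be called like a normal Python function. A quick check (the values here are just examples):

print(f(tf.constant(1), tf.constant(2)))  # tf.Tensor(3, shape=(), dtype=int32)
print(h(tf.constant(2), tf.constant(3)))  # tf.Tensor(6, shape=(), dtype=int32)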

Also, if you call another function from inside a @tf.function-decorated function, the conversion extends to that function as well, so you don't need to go through every function and add @tf.function to each one (in fact, decorating only the outermost function is the recommended style). So it seems you can easily benefit from Graph mode just by adding the decorator to the heavy parts of your code. For simple code like that in the tutorials this is indeed no big deal, but once you try to do something a little more complicated, you can run into behavior that seems inexplicable unless you know the rules tf.function follows, so you need to be careful.
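
For example, in the following minimal sketch (the function names are my own, not from the official docs), only outer is decorated, but the call to inner is traced into the same graph, so inner also runs in Graph mode:

import tensorflow as tf

def inner(x):
    return x * 2

@tf.function
def outer(x):
    return inner(x) + 1

print(outer(tf.constant(3)))  # tf.Tensor(7, shape=(), dtype=int32)

At the beginning of the official site, there is the following description.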

--Don't rely on Python side effects like object mutation or appending to Python lists
--tf.function performs best with TensorFlow ops, rather than with NumPy or Python primitive types
--When in doubt, use the for x in y idiom

Some of these are hard to interpret on their own, but they are easier to understand with a concrete example, so let's take a look.

Experiment

First of all, for the sake of simplicity, prepare the following simple function.

tracing_test.py


import tensorflow as tf

@tf.function
def double(a):
    print("Tracing with, {}".format(a) )
    return a + a

It is a simple function that doubles its input argument and returns it; it also prints the argument it received. It works with integers, real numbers, and strings. Let's run it with a few patterns.

tracing_test.py


print(double(3))
print()
print(double(2))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant('a')))
print()
print(double(tf.constant('b')))

The result is as follows.

Tracing with, 3
tf.Tensor(6, shape=(), dtype=int32)

Tracing with, 2
tf.Tensor(4, shape=(), dtype=int32)

Tracing with, Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with, Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

tf.Tensor(b'bb', shape=(), dtype=string)

The result is a little strange. Only for the last call, double(tf.constant('b')), the print statement was not executed. If you run the same program again, you get an even stranger result.

tracing_test.py


print(double(3))
print()
print(double(2))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant('a')))
print()
print(double(tf.constant('b')))

The result is as follows.

tf.Tensor(6, shape=(), dtype=int32)

tf.Tensor(4, shape=(), dtype=int32)

tf.Tensor(2.2, shape=(), dtype=float32)

tf.Tensor(b'aa', shape=(), dtype=string)

tf.Tensor(b'bb', shape=(), dtype=string)

The correct values are returned, but the print statements inside the function are not executed at all. What does this mean?

Tracing

In fact, this strange behavior involves a process called Tracing, which runs when tf.function builds and optimizes a function as a computational graph. tf.function converts a function that contains not only TensorFlow operations but also Python-specific processing into a computational graph, and it omits the Python-specific processing (the print statement in this case) that is not actually needed to execute the graph. But then why did the print run at first? Functions written in Python have no explicit types on their arguments. While it is convenient to be able to pass in all kinds of values, this is a problem for tf.function, which is trying to build an optimal computational graph. Therefore, the first time the function is called with a value or type that has not been passed in before, a process called Tracing runs the function as ordinary Python code, executing all of its Python-specific processing while recording the TensorFlow operations into a graph. I said "a value or type that has not been passed in before", but strictly speaking, the criteria are as follows (the object-id rule is illustrated in a sketch after the list).

--For Python primitive types, a new trace is made when a different value comes in
--For Python objects, a new trace is made when an object with a different id comes in
--For TensorFlow Tensors, a new trace is made when a different dtype or shape comes in
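
The object-id rule is easy to trip over: two objects with identical contents still trigger separate traces. A hypothetical sketch (Params is an illustrative class of my own, not TensorFlow API):

import tensorflow as tf

class Params:
    def __init__(self, scale):
        self.scale = scale

@tf.function
def apply_scale(p, x):
    print("Tracing with", p)
    return x * p.scale

p1 = Params(2)
p2 = Params(2)                     # equal contents, but a different object id
apply_scale(p1, tf.constant(1.0))  # first call -> Tracing
apply_scale(p1, tf.constant(2.0))  # same id, same dtype/shape -> cached graph
apply_scale(p2, tf.constant(3.0))  # different id -> Tracing again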

With these rules in mind, the mechanism behind the strange behavior above is as follows.

tracing_test.py


print(double(3)) # Tracing: first time seeing this value
print()
print(double(2)) # Tracing: first time seeing this value
print()
print(double(tf.constant(1.1))) # Tracing: first time seeing this dtype
print()
print(double(tf.constant('a'))) # Tracing: first time seeing this dtype
print()
print(double(tf.constant('b'))) # dtype/shape seen before: runs the optimized graph

Once Tracing has run, TensorFlow saves the resulting computational graph internally. The next time the function is called with a previously traced value (or dtype/shape), the optimized computational graph is executed directly. That is why the Python print statement was not executed on the last call in the program above, and why none of the print statements were executed on the second run.
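
If you want to keep this cache from growing, or to avoid retracing altogether, one option described in the official guide is to fix the input signature with tf.TensorSpec, so that only a single graph is ever built. A minimal sketch, assuming a float32 input of any shape:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def double_float(a):
    print("Tracing with, {}".format(a))
    return a + a

double_float(tf.constant(1.1))    # Tracing happens once, here
double_float(tf.constant(2.2))    # reuses the same graph
double_float(tf.constant([1.0]))  # shape=None allows any shape, so still no retrace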

So what should we do?

So this brings us back to the first guideline:

--Don't rely on Python side effects like object mutation or appending to Python lists

That is what it means. As for the earlier print, if you write it as tf.print instead, it will be executed every time. You can also avoid strange behavior, and improve performance, by sticking to TensorFlow-native constructs: use tf.summary for logging, or tf.Variable when you want to update values inside the function. Note that I am not saying you must avoid Python-specific processing entirely; being able to mix in Pythonic code makes programming much more flexible, and that is a big advantage. Just be aware that if you define a function without thinking about any of this and simply add tf.function, it may behave strangely.
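
Here is a minimal sketch of that advice (the function and variable names are my own): tf.print becomes a node in the graph and so runs on every call, and tf.Variable holds state that is updated inside the graph rather than through a Python side effect.

import tensorflow as tf

call_count = tf.Variable(0)

@tf.function
def double_verbose(a):
    call_count.assign_add(1)        # graph-side state update: runs on every call
    tf.print("Executing with:", a)  # graph-side print: runs on every call
    return a + a

double_verbose(tf.constant(1))
double_verbose(tf.constant(2))
print(call_count.numpy())  # 2: both calls ran the update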

Summary

It's nice to be able to benefit from both Graph and Eager modes in TensorFlow 2.0, but if you create functions that lean too heavily on Python-specific features, you may also step on unexpected bugs beyond the pitfalls described above. When using TensorFlow, use TensorFlow-native methods as much as possible, and when you do use Python-specific features, design with the behavior of AutoGraph and Tracing in mind. The official site covers many other things to watch out for, as well as ways to control this behavior, so if you are going to implement your own model and need tf.function, please read it.
