Save specific variables in a TensorFlow Session

Use tensorflow.train.Saver to save TensorFlow Variables to a file. The method described in the tutorial saves every Variable in the Session. To save or restore only specific Variables, pass the tensorflow.train.Saver constructor a dictionary that maps names to the Variables you want to target.

This makes it possible to restore Variables individually from multiple files.

save.py


import tensorflow as tf

def get_particular_variables(name):
    # map variable name -> Variable for every variable whose name contains `name`
    return {v.name: v for v in tf.global_variables() if name in v.name}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2


sess = tf.InteractiveSession()

# define the variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)

# create a Saver that handles only the variables whose name includes 'foo'
saver = tf.train.Saver(get_particular_variables('foo'))

# initialize all variables
sess.run(tf.global_variables_initializer())

print(var0.eval(), var1.eval(), var2.eval())

# save the selected variables to a file
saver.save(sess, './bar_val')

restore.py


import tensorflow as tf

def get_particular_variables(name):
    # map variable name -> Variable for every variable whose name contains `name`
    return {v.name: v for v in tf.global_variables() if name in v.name}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2

sess = tf.InteractiveSession()

# define the variables with different initial values
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)

# create a Saver that restores only the variables whose name includes 'foo'
saver = tf.train.Saver(get_particular_variables('foo'))

# initialize all variables
sess.run(tf.global_variables_initializer())
print('before restoring: ', var0.eval(), var1.eval(), var2.eval())

# restore the variables from the file
saver.restore(sess, './bar_val')
print('after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval())

However, this method requires care with long names and nested variable scopes. For example, suppose the graph contains the following variables:

variable    variable name
var0        Variable:0
var1        foo/Variable:0
var2        foo/bar/Variable:0
var3        foobar/Variable:0

In such a case, calling get_particular_variables('foo') as defined above returns var1, var2, and var3, because it matches every name that contains the substring 'foo'. Depending on the search condition, extra variables end up being saved, which can cause unexpected bugs when restoring.
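
One possible way to narrow the match is to compare against the scope prefix instead of searching for a substring. The following is only a sketch in plain Python that works on the name strings from the example above; select_by_scope and select_by_substring are hypothetical helper names, not part of TensorFlow, and the sketch assumes that scopes are separated by '/' as in the names shown.

```python
def select_by_substring(names, text):
    """Substring rule, the same condition get_particular_variables uses."""
    return [n for n in names if text in n]


def select_by_scope(names, scope):
    """Hypothetical stricter rule: keep only names inside `scope`.

    A name is inside the scope when it starts with the scope followed
    by '/', so 'foobar/...' is not treated as part of 'foo'.
    """
    prefix = scope.rstrip('/') + '/'
    return [n for n in names if n.startswith(prefix)]


# Variable names taken from the example above.
names = [
    'Variable:0',
    'foo/Variable:0',
    'foo/bar/Variable:0',
    'foobar/Variable:0',
]

print(select_by_substring(names, 'foo'))  # var1, var2 and var3
print(select_by_scope(names, 'foo'))      # var1 and var2 only
print(select_by_scope(names, 'foo/bar'))  # var2 only
```

The same startswith condition could replace the substring test inside get_particular_variables. Note that the prefix rule still includes nested scopes such as foo/bar, so whether it is strict enough depends on which variables you actually intend to save.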
