Utilisez tensorflow.train.Saver pour enregistrer les variables Tensorflow dans un fichier. Toutes les variables de la session sont enregistrées par la méthode décrite dans le didacticiel. Pour enregistrer / restaurer uniquement une variable spécifique, attribuez à la fonction d'initialisation de tensorflow.train.Saver une liste de variables que vous souhaitez cibler dans un type de dictionnaire.
Cela permet de lire les variables individuellement à partir de plusieurs fichiers.
save.py
import tensorflow as tf
def get_particular_variables(name):
return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}
def define_variables(var0_value, var1_value, var2_value):
var0 = tf.Variable([var0_value])
with tf.variable_scope('foo'):
var1 = tf.Variable([var1_value])
with tf.variable_scope('bar'):
var2 = tf.Variable([var2_value])
return var0, var1, var2
sess = tf.InteractiveSession()
# defines variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)
# saving only variables whose name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))
# initializing all of variables
sess.run(tf.initialize_all_variables())
print var0.eval(), var1.eval(), var2.eval()
# saving into file
saver.save(sess, './bar_val')
restore.py
import tensorflow as tf
def get_particular_variables(name):
return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}
def define_variables(var0_value, var1_value, var2_value):
var0 = tf.Variable([var0_value])
with tf.variable_scope('foo'):
var1 = tf.Variable([var1_value])
with tf.variable_scope('bar'):
var2 = tf.Variable([var2_value])
return var0, var1, var2
sess = tf.InteractiveSession()
# defines variables
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)
# restoring only variables whole name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))
# initializing all of variables
sess.run(tf.initialize_all_variables())
print 'before restoring: ', var0.eval(), var1.eval(), var2.eval()
# restoring variable from file
saver.restore(sess, './bar_val')
print 'after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval()
Cependant, avec cette méthode, il faut faire attention aux noms longs et à la hiérarchie des espaces de noms. Par exemple
variable | name-of-variable |
---|---|
var0 | Variable:0 |
var1 | foo/Variable:0 |
var2 | foo/bar/Variable:0 |
var3 | foobar/Variable:0 |
Dans un tel cas, l'exécution de get_particular_variables ('foo') ci-dessus renverra var1, var2 et var3. De cette façon, en fonction des conditions de recherche, des variables supplémentaires sont enregistrées, ce qui peut provoquer des bogues inattendus lors de la restauration.
Recommended Posts