Enregistrer une variable spécifique dans tensorflow.session

Utilisez tensorflow.train.Saver pour enregistrer les variables Tensorflow dans un fichier. Toutes les variables de la session sont enregistrées par la méthode décrite dans le didacticiel. Pour enregistrer / restaurer uniquement une variable spécifique, attribuez à la fonction d'initialisation de tensorflow.train.Saver une liste de variables que vous souhaitez cibler dans un type de dictionnaire.

Cela permet de lire les variables individuellement à partir de plusieurs fichiers.

save.py


import tensorflow as tf

def get_particular_variables(name):
    return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2


sess = tf.InteractiveSession()

# defines variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)

# saving only variables whose name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))

# initializing all of variables
sess.run(tf.initialize_all_variables())

print var0.eval(), var1.eval(), var2.eval()

# saving into file
saver.save(sess, './bar_val')

restore.py


import tensorflow as tf

def get_particular_variables(name):
    return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2

sess = tf.InteractiveSession()

# defines variables
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)

# restoring only variables whole name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))

# initializing all of variables
sess.run(tf.initialize_all_variables())
print 'before restoring: ', var0.eval(), var1.eval(), var2.eval()

# restoring variable from file
saver.restore(sess, './bar_val')
print 'after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval()

Cependant, avec cette méthode, il faut faire attention aux noms longs et à la hiérarchie des espaces de noms. Par exemple

variable name-of-variable
var0 Variable:0
var1 foo/Variable:0
var2 foo/bar/Variable:0
var3 foobar/Variable:0

Dans un tel cas, l'exécution de get_particular_variables ('foo') ci-dessus renverra var1, var2 et var3. De cette façon, en fonction des conditions de recherche, des variables supplémentaires sont enregistrées, ce qui peut provoquer des bogues inattendus lors de la restauration.

Recommended Posts

Enregistrer une variable spécifique dans tensorflow.session
Entrez une valeur spécifique pour la variable dans tensorflow
Enregistrer les fichiers au format YAML avec PyYAML
Comment incorporer des variables dans des chaînes python
Cloner avec une branche / balise spécifique dans GitPython
Extraire des lignes contenant une "chaîne" spécifique avec Pandas
Comment compter les nombres dans une plage spécifique
Obtenir des lignes contenant des éléments spécifiques dans np.where
Je souhaite intégrer une variable dans une chaîne Python
dict in dict Transforme un dict en dict
Enregistrer les tweets contenant des mots-clés spécifiques sur Twitter au format CSV
Que contient cette variable (lorsque le script Python est en cours d'exécution)
[Linux] Comment mettre votre IP dans une variable
Arrêter une instance avec une balise spécifique dans Boto3
[Sublime Text 2] Toujours exécuter un fichier spécifique dans le projet
Enregistrez le modèle pystan et les résultats dans un fichier pickle
Obtenez le nombre d'éléments spécifiques dans la liste python
Ecrire un réseau de co-auteurs dans un domaine spécifique en utilisant les informations d'arxiv
Entrer simultanément des données spécifiques sur une feuille spécifique dans de nombreux Excel
Une histoire sur la tentative d'implémentation de variables privées en Python.
Prendre une capture d'écran en Python
Créer une fonction en Python
Créer un dictionnaire en Python
Remplacer la méthode save de django-models
Créer un bookmarklet en Python
Dessinez un cœur en Python
Type de gain de temps avec SQLAlchemy
Comment vérifier la taille de la mémoire d'une variable en Python
Enregistrez le fichier d'authentification Pydrive dans un répertoire différent du script
Si vous souhaitez affecter une exportation csv à une variable en python
[Golang] Vérifiez si une chaîne de caractères spécifique est incluse dans la chaîne de caractères