[TensorFlow 2] Comment vérifier le contenu de Tensor en mode graphique

introduction

Lors de l'utilisation d'une fonction personnalisée pour la fonction de perte pendant l'entraînement ou la partie de conversion de l'ensemble de données dans TensorFlow (2.x), j'ai essayé de confirmer "La valeur est-elle conforme aux attentes?" L'appel) peut ne pas afficher la valeur. Vous pouvez utiliser un débogueur tel que tfdbg``, mais voici un moyen plus simple de faire ce que l'on appelle le "débogage d'impression".

Pour tfdbgUtilisez tfdbg pour écraser Keras nan et inf --Qiita

Environnement de vérification

Dans le code ci-dessous, on suppose que ↓ est écrit.

import tensorflow as tf 

Comportement lorsque print () est effectué sur Tensor

Eager Execution est maintenant la valeur par défaut dans TensorFlow 2.x, donc tant que vous assemblez Tensor avec l'interpréteur, vous pouvez simplement faire print () '' pour voir la valeur de Tensor. Vous pouvez également obtenir la valeur ndarrayavec .numpy () ''

Si vous pouvez voir la valeur

x = tf.constant(10)
y = tf.constant(20)
z = x + y
print(z)         # tf.Tensor(30, shape=(), dtype=int32)
print(z.numpy()) # 30

Si la valeur ne peut pas être confirmée (1)

Comme l'évaluation séquentielle de la valeur de Tensor '' est lente au moment de l'apprentissage / évaluation, il est possible de définir la fonction à traiter en mode graphique en ajoutant le décorateur @ tf.function ''. ** Les opérations définies en mode graphique sont traitées ensemble comme un graphique de calcul (en dehors de Python), vous ne pouvez donc pas voir les valeurs dans le processus de calcul. ** **

@tf.function
def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# Tensor("mul:0", shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

Vous pouvez sortir le résultat du calcul en dehors de dot () '', mais vous ne pouvez pas voir la valeur à l'intérieur. De plus, même si vous appelez dot () '' plusieurs fois, le `` print () '' à l'intérieur n'est fondamentalement exécuté qu'une seule fois au moment de l'analyse du graphique.

Bien sûr, dans ce cas, vous pouvez supprimer le `` @ tf.function () '' pour voir la valeur.

def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# tf.Tensor([3 2], shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor([0 4], shape=(2,), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

Si la valeur ne peut pas être confirmée (2)

Il peut être implicitement exécuté en mode graphique, comme lors du traitement de map () '' pour tf.data.Datasetou lors de l'utilisation de votre propre fonction de perte dans Keras. Dans ce cas, vous ne pouvez pas voir la valeur au milieu avec print () '' même si ** @ tf.function n'est pas ajouté. ** **

def fourth_power(x):
    z = x * x
    print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)

# Tensor("mul:0", shape=(), dtype=int64)
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(1, shape=(), dtype=int64)
# tf.Tensor(16, shape=(), dtype=int64)
# :

tf.print()

Utilisez tf.print () '' pour voir la valeur de Tensor '' dans le processus s'exécutant en mode graphique. tf.print | TensorFlow Core v2.1.0

Vous pouvez afficher les valeurs comme suit, ou vous pouvez utiliser tf.shape () '' pour afficher les dimensions et la taille de Tensor ''. Veuillez également l'utiliser pour le débogage lorsque vous vous mettez en colère lorsque les dimensions et les tailles ne correspondent pas pour une raison quelconque dans votre propre fonction.

@tf.function
def dot(x, y):
    tmp = x * y
    tf.print(tmp)
    tf.print(tf.shape(tmp))
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# [3 2] 
# [2]
# tf.Tensor(5, shape=(), dtype=int32)
# [0 4]
# [2]
# tf.Tensor(4, shape=(), dtype=int32)

Si vous souhaitez déboguer Dataset.map () '', vous pouvez utiliser tf.print () ''.

def fourth_power(x):
    z = x * x
    tf.print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)
# 0
# tf.Tensor(0, shape=(), dtype=int64)
# 1
# tf.Tensor(1, shape=(), dtype=int64)
# 4
# tf.Tensor(16, shape=(), dtype=int64)
# :

Exemple de tf.print () ne fonctionne pas bien

Ensuite, si vous ne pensez à rien et que vous devez sortir le contenu de Tensor avec tf.print () '', ce n'est pas le cas. tf.print () '' est exécuté lorsque le processus est réellement exécuté comme un graphe calculé. En d'autres termes, si vous constatez que les types et dimensions ne correspondent pas au moment de l'analyse du graphe sans effectuer ** de calcul et qu'une erreur se produit, le traitement de `` tf.print () '' ne sera pas exécuté. ** C'est compliqué ...

@tf.function
def add(x, y):
    z = x + y
    tf.print(z) #Cette impression ne sera pas exécutée
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

Dans un tel cas, il est préférable d'utiliser `` print () '' ordinaire pour vérifier si les données de la forme souhaitée sont transmises.

@tf.function
def add(x, y):
    print(x) #Cette impression est exécutée lors de l'analyse du graphe de calcul
    print(y)
    z = x + y
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# Tensor("x:0", shape=(2,), dtype=int32)
# Tensor("y:0", shape=(3,), dtype=int32)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

Comment mettre à jour les variables dans votre propre fonction

Par exemple, pensez à compter le nombre de fois que votre propre fonction a été appelée en mode graphique et à afficher le nombre d'appels.

Mauvais exemple

count = 0

@tf.function
def dot(x, y):
    global count
    tmp = x * y
    count += 1
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 1 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

"1" sera également affiché la deuxième fois. En effet, `` count + = 1 '' en tant que code Python n'est exécuté qu'une seule fois lors de l'analyse du graphique.

Un exemple qui fonctionne bien

La bonne réponse est d'utiliser tf.Variable () et ```assign_add () `` comme indiqué ci-dessous. tf.Variable | TensorFlow Core v2.1.0

count = tf.Variable(0)

@tf.function
def dot(x, y):
    tmp = x * y
    count.assign_add(1)
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 2 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

Article de référence

Amélioration des performances avec tf.function | TensorFlow Core (Document officiel)

Recommended Posts

[TensorFlow 2] Comment vérifier le contenu de Tensor en mode graphique
Comment vérifier la version de Django
Comment vérifier la taille de la mémoire d'une variable en Python
Comment vérifier la taille de la mémoire d'un dictionnaire en Python
La première intelligence artificielle. Comment vérifier la version de Tensorflow installée.
Comment vérifier si le contenu du dictionnaire est le même en Python par valeur de hachage
Comment comparer si le contenu des objets dans scipy.sparse.csr_matrix est le même
Comment mentionner un groupe d'utilisateurs avec une notification de mou, comment vérifier l'ID d'un groupe d'utilisateurs
Comment installer le framework d'apprentissage en profondeur Tensorflow 1.0 dans l'environnement Windows Anaconda
Comment vérifier en Python si l'un des éléments d'une liste est dans une autre liste
[Ubuntu] Comment supprimer tout le contenu du répertoire
Comment trouver le nombre optimal de clusters pour les k-moyennes
Comment voir le contenu du fichier ipynb du notebook Jupyter
Comment connecter le contenu de la liste dans une chaîne de caractères
Comment exécuter du code TensorFlow 1.0 en 2.0
[Super facile! ] Comment afficher le contenu des dictionnaires et des listes incluant le japonais en Python
Comment spécifier un nombre infini de tolérances dans le contrôle de validation d'argument numérique d'argparse
Comment déterminer l'existence d'un élément sélénium en Python
Comment connaître la structure interne d'un objet en Python
Comment changer la couleur du seul bouton pressé avec Tkinter
[EC2] Comment installer Chrome et le contenu de chaque commande
Exportez le contenu de ~ .xlsx dans le dossier en HTML avec Python
Comment obtenir les coordonnées de sommet d'une entité dans ArcPy
Créez une fonction pour obtenir le contenu de la base de données dans Go
Vérifiez le comportement du destroyer en Python
Comment vérifier la version d'opencv avec python
Bases de PyTorch (1) -Comment utiliser Tensor-
Comment vérifier le GAE local à partir du navigateur iPhone dans le même LAN
[Introduction à Python] Comment trier efficacement le contenu d'une liste avec le tri par liste
J'ai fait un programme pour vérifier la taille d'un fichier avec Python
J'ai essayé d'afficher la valeur d'altitude du DTM dans un graphique
Comment copier et coller le contenu d'une feuille au format JSON avec une feuille de calcul Google (en utilisant Google Colab)
Comment calculer la volatilité d'une marque
Comment utiliser la bibliothèque C en Python
Comment trouver la zone du diagramme de Boronoi
L'inexactitude de Tensorflow était due à log (0)
Résumé de la façon d'importer des fichiers dans Python 3
Comment exécuter CNN en notation système 1 avec Tensorflow 2
Comment utiliser la bibliothèque de dessins graphiques Bokeh
Résumé de l'utilisation de MNIST avec Python
Comment vérifier / extraire des fichiers dans un package RPM
Comment obtenir les fichiers dans le dossier [Python]
[TF] Comment créer Tensorflow dans un environnement Proxy
Comment passer le résultat de l'exécution d'une commande shell dans une liste en Python
Comment identifier de manière unique la source d'accès dans la vue de classe générique Django
Comment compter le nombre d'éléments dans Django et sortir dans le modèle
Un mémorandum expliquant comment exécuter la commande magique! Sudo dans Jupyter Notebook
Comment trouver le coefficient de la courbe approximative passant par les sommets en Python
Comment changer l'apparence du champ de clé étrangère non sélectionné dans le formulaire modèle de Django
Comment rendre la largeur de police du notebook jupyter mis dans pyenv égale
Comment obtenir une liste de fichiers dans le même répertoire avec python
Un moyen simple de vérifier la source des modules Python
Comment récupérer la nième plus grande valeur en Python
J'ai essayé de représenter graphiquement les packages installés en Python
Comment connaître le numéro de port du service xinetd
Comment obtenir le nom de la variable lui-même en python
Comment exécuter le module Ansible ajouté dans Ansible Tower
Modèle de script python pour lire le contenu du fichier
Comment identifier l'élément avec le plus petit nombre de caractères dans une liste Python?
Omettre les graduations du graphique après la virgule décimale dans matplotlib