Je veux utiliser self avec Backpropagation (tf.custom_gradient) (tensorflow)

Écriture normale lors de l'utilisation de custom_graident
@tf.custom_gradient
def gradient_reversal(x):
  y = x
  def grad(dy):
    return - dy
  return y, grad

#Lorsqu'il est utilisé dans le modèle
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def call(self, x):
        return gradient_reversal(x)
Si vous souhaitez utiliser des variables hors de portée (telles que self) dans custom_gradient
class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x

        def backward(w, variables=None):
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            grads = tape.gradient(z, [w])
            return z, grads

        return y, backward

    def call(self, x):
        return self.forward(x)
TypeError: If using @custom_gradient with a function that uses variables, then grad_fn must accept a keyword argument 'variables'.

Code de vérification

import tensorflow as tf


optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            z = self.alpha * w
            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            return z, variables

        return y, backward

    def call(self, x):
        return self.forward(x)


class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            grads = tape.gradient(z, [w])

            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            tf.print("  alpha: ", self.alpha)
            tf.print("  grads: ", grads)
            return z, grads

        return y, backward

    def call(self, x):
        return self.forward(x)


for model in [MyModel(), MyModel2()]:
    print()
    print()
    print()
    print(model.name)
    for i in range(10):
        with tf.GradientTape() as tape:
            x = tf.Variable(1.0, tf.float32)
            y = model(x)

        grads = tape.gradient(y, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        tf.print("step")
        tf.print("  y:", y)
        tf.print("  grads:", grads)
        print()

Recommended Posts

Je veux utiliser self avec Backpropagation (tf.custom_gradient) (tensorflow)
Je veux utiliser le jeu de données R avec python
Je veux imprimer dans la notation d'inclusion
Je veux utiliser jar de python
Je veux utiliser Linux sur mac
Je souhaite utiliser la console IPython Qt
Je veux intégrer Matplotlib dans PySimpleGUI
Implémentation de DQN avec TensorFlow (je voulais ...)
Je souhaite utiliser Django Debug Toolbar dans les applications Ajax
Je veux faire le test de Dunnett en Python
Je veux utiliser MATLAB feval avec python
Je veux corriger Datetime.now dans le test de Django
Je veux créer une fenêtre avec Python
Je souhaite stocker les informations de la base de données dans la liste
Je veux fusionner des dictionnaires imbriqués en Python
Je souhaite utiliser le répertoire temporaire avec Python2
Je veux utiliser le solveur ceres de python
Je ne veux pas utiliser -inf avec np.log
Je souhaite utiliser ip vrf avec SONiC
[Je veux classer les images à l'aide de Tensorflow] (2) Classifions les images
Je souhaite utiliser la fonction d'activation Mish
Je veux afficher la progression en Python!
Je souhaite utiliser Python dans l'environnement de pyenv + pipenv sous Windows 10
Je souhaite intégrer une variable dans une chaîne Python
Je veux faire la transition avec un bouton sur le ballon
Je veux écrire en Python! (2) Écrivons un test
Même avec JavaScript, je veux voir Python `range ()`!
Je veux échantillonner au hasard un fichier avec Python
Je veux travailler avec un robot en python.
Je veux écrire en Python! (3) Utiliser des simulacres
[TensorFlow] Je souhaite traiter des fenêtres avec Ragged Tensor
J'ai essayé de résumer comment utiliser les pandas de python
Je souhaite utiliser OpenJDK 11 avec Ubuntu Linux 18.04 LTS / 18.10
Je veux faire quelque chose avec Python à la fin
Je veux manipuler des chaînes dans Kotlin comme Python!
Je souhaite utiliser une source de données python dans Re: Dash pour obtenir les résultats de la requête.
Je veux résoudre SUDOKU
[TensorFlow] Je souhaite maîtriser l'indexation pour Ragged Tensor
Je veux utiliser la dernière version de gcc même si je n'ai pas les privilèges sudo! !!
Je souhaite utiliser facilement les fonctions R avec le notebook ipython
Je souhaite supprimer facilement une colonne contenant NA dans R
Je veux faire quelque chose comme sort uniq en Python
Je souhaite utiliser uniquement le traitement de normalisation SudachiPy
[Python] Je souhaite utiliser l'option -h avec argparse
Je souhaite utiliser un environnement virtuel avec jupyter notebook!
[Django] Je souhaite me connecter automatiquement après une nouvelle inscription
Je veux rendre le type de dictionnaire dans la liste unique
[Introduction à Pytorch] Je souhaite générer des phrases dans des articles de presse
Je veux compter des valeurs uniques dans un tableau ou un tuple
Je veux aligner les nombres valides dans le tableau Numpy
Je veux pouvoir exécuter Python avec VS Code
Je veux ajouter un joli complément à input () en python
Je veux utiliser VS Code et Spyder sans anaconda! !! !!
Je ne voulais pas écrire la clé AWS dans le programme
[Pour ceux qui veulent utiliser TPU] J'ai essayé d'utiliser l'API de détection d'objets Tensorflow 2
Je souhaite utiliser la traduction de raccourcis comme l'application DeepL même sous Linux
Comment utiliser les classes dans Theano
[Linux] Je souhaite connaître la date à laquelle l'utilisateur s'est connecté
Mock in python - Comment utiliser mox
Je veux résoudre APG4b avec Python (seulement 4.01 et 4.04 au chapitre 4)