I want to use self in Backpropagation (tf.custom_gradient) (tensorflow)

The usual way to write a function with tf.custom_gradient:
import tensorflow as tf

@tf.custom_gradient
def gradient_reversal(x):
    y = x
    def grad(dy):
        # Flip the sign of the upstream gradient.
        return -dy
    return y, grad

# Used inside a model
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def call(self, x):
        return gradient_reversal(x)
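
As a quick check that the sign really flips, the gradient can be read back with tf.GradientTape. This is only a minimal sketch: the input values and the factor 2.0 are my own choices for illustration, and tf.identity is used in place of `y = x`.

```python
import tensorflow as tf


@tf.custom_gradient
def gradient_reversal(x):
    y = tf.identity(x)

    def grad(dy):
        # Pass the upstream gradient back with its sign flipped.
        return -dy

    return y, grad


x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # constants are not watched automatically
    loss = tf.reduce_sum(gradient_reversal(x) * 2.0)

dx = tape.gradient(loss, x)
print(dx.numpy())  # -> [-2. -2. -2.]
```

Without the reversal, the gradient of this loss would be +2 for every element.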

To use variables from outside the function's scope (such as self) in custom_gradient, decorate a method and give the gradient function a variables keyword argument:
class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x

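        def backward(w, variables=None):
            # w is the upstream gradient. The variables keyword is required
            # because forward uses self.alpha.
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            # The second return value is used as the gradients of the variables.
            grads = tape.gradient(z, [w])
            return z, grads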

        return y, backward

    def call(self, x):
        return self.forward(x)

If backward does not accept the variables keyword argument, the following error is raised:
TypeError: If using @custom_gradient with a function that uses variables, then grad_fn must accept a keyword argument 'variables'.
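
The two cases can be compared side by side. This is a sketch under assumptions: it runs in TensorFlow 2 eager mode, the module-level `alpha` is a hypothetical stand-in for self.alpha, and the function names are made up for illustration. The input gradient is returned as a one-element list (one entry per input), and the variable gradients as a list with one entry per variable. The exact wording of the error may differ between TensorFlow versions, so only the exception type is checked.

```python
import tensorflow as tf

alpha = tf.Variable(3.0)  # hypothetical stand-in for self.alpha


@tf.custom_gradient
def missing_variables_kwarg(x):
    y = alpha * x

    def backward(dy):  # no `variables` keyword argument
        return -alpha * dy

    return y, backward


@tf.custom_gradient
def scaled_reversal(x):
    y = alpha * x

    def backward(dy, variables=None):
        # Gradient for the input x, with the sign flipped.
        grad_x = -alpha * dy
        # One gradient per entry in `variables` (here only alpha).
        grad_vars = [tf.reduce_sum(dy * x)]
        return [grad_x], grad_vars

    return y, backward


x = tf.constant([1.0, 2.0])

# The version without the keyword is rejected as soon as it is called.
try:
    missing_variables_kwarg(x)
    raised = False
except TypeError:
    raised = True
print("TypeError raised:", raised)

with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(scaled_reversal(x))

dx, dalpha = tape.gradient(loss, [x, alpha])
print(dx.numpy(), dalpha.numpy())
```

Here `dx` is the reversed gradient for the input and `dalpha` is whatever backward returned for the variable, so the optimizer sees exactly what the custom gradient function decides to hand back.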

Verification code

import tensorflow as tf


optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

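        def backward(w, variables=None):
            z = self.alpha * w
            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            # The second return value is used as the gradients of `variables`;
            # here the variables themselves are returned.
            return z, variables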

        return y, backward

    def call(self, x):
        return self.forward(x)


class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            grads = tape.gradient(z, [w])

            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            tf.print("  alpha: ", self.alpha)
            tf.print("  grads: ", grads)
            return z, grads

        return y, backward

    def call(self, x):
        return self.forward(x)


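for model in [MyModel(), MyModel2()]:
    print("\n\n")
    print(model.name)
    # A fresh optimizer per model: reusing one optimizer for a different set
    # of variables can raise an error in newer Keras versions.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    for i in range(10):
        with tf.GradientTape() as tape:
            # dtype must be passed by keyword; the second positional
            # argument of tf.Variable is trainable.
            x = tf.Variable(1.0, dtype=tf.float32)
            y = model(x)

        grads = tape.gradient(y, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        tf.print("step")
        tf.print("  y:", y)
        tf.print("  grads:", grads)
        print()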
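
Instead of reading the printed output, the same idea can be checked with explicit values. The sketch below is not the code above: it assumes a plain tf.Module instead of tf.keras.Model to stay small, the class name `ScaledReversal` is hypothetical, and backward returns a hand-written gradient for self.alpha (the sum of upstream gradient times input) as a one-element list.

```python
import tensorflow as tf


class ScaledReversal(tf.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.alpha = tf.Variable(2.0, name="alpha")

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x

        def backward(dy, variables=None):
            # self is available here through the closure.
            grad_x = -self.alpha * dy
            grad_alpha = tf.reduce_sum(dy * x)
            return [grad_x], [grad_alpha]

        return y, backward

    def __call__(self, x):
        return self.forward(x)


layer = ScaledReversal()
x = tf.constant([1.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(layer(x))

dx, dalpha = tape.gradient(loss, [x, layer.alpha])
print(dx.numpy(), dalpha.numpy())
```

The gradient for the input comes back negated and scaled by alpha, while the gradient for self.alpha is the value computed inside backward.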
