Set update rules for each parameter in Chainer v2

Introduction

Chainer v2 introduced UpdateRule, which makes it possible to set an update rule for each parameter by manipulating its UpdateRule instance. For example, you can change the learning rate of individual parameters or stop updating some of them.

What are parameters?

The parameters in this post are chainer.Parameter instances. chainer.Parameter is a class that inherits chainer.Variable and holds the parameters of a chainer.Link. For example, chainer.links.Convolution2D has two parameters, W and b.
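
A quick way to check this is to inspect a link's parameters directly. Below is a minimal sketch (the link sizes are chosen arbitrarily for illustration):

import chainer
from chainer import links as L

conv = L.Convolution2D(3, 8, ksize=3)  # in_channels=3, out_channels=8, 3x3 kernel
print(isinstance(conv.W, chainer.Parameter))  # True
print(isinstance(conv.W, chainer.Variable))   # True: Parameter inherits Variable
print(conv.W.shape)  # (8, 3, 3, 3)
print(conv.b.shape)  # (8,)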

UpdateRule

chainer.UpdateRule is a class that defines how a parameter is updated. Derived classes implement concrete update algorithms such as SGD. UpdateRule has, among others, the following attributes:

- enabled: a bool indicating whether the parameter is updated.
- hyperparam: the hyperparameters the update rule uses, such as the learning rate lr.

You can stop updating a parameter or change its learning rate by manipulating enabled or hyperparam.

When UpdateRule instances are created

The UpdateRule instance of each parameter is created when you call setup() on a chainer.Optimizer instance.
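
A minimal sketch that checks this (a bare Linear link is enough, since setup() accepts any Link): before setup() the parameter has no UpdateRule, and afterwards it holds one whose hyperparameters come from the optimizer.

from chainer import links as L
from chainer import optimizers

link = L.Linear(2, 2)
print(link.W.update_rule)  # None: no UpdateRule exists before setup()

optimizer = optimizers.SGD(lr=0.1)
optimizer.setup(link)
print(link.W.update_rule)                # an SGD update rule instance
print(link.W.update_rule.enabled)        # True: updates are enabled by default
print(link.W.update_rule.hyperparam.lr)  # 0.1: taken over from the optimizer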

Example

Suppose you build the following neural network.

class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(2, 2)
            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h

Stop updating parameters

Parameter updates can be stopped per parameter or per link.

Specify per parameter

To prevent a particular parameter from being updated, set its update_rule.enabled to False. Example:

net.l1.W.update_rule.enabled = False
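
Note that backward() still computes the gradient of net.l1.W; setting enabled to False only makes the optimizer skip this parameter's update step.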

Specify per Link

To prevent a whole Link from being updated, call disable_update(). Conversely, call enable_update() to enable updates for all the parameters the Link has, as the sketch after the example shows.

Example:

net.l1.disable_update()
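
disable_update() sets update_rule.enabled to False for every parameter the link holds, so you can verify it as below (this sketch assumes optimizer.setup(net) has already been called):

net.l1.disable_update()
print(net.l1.W.update_rule.enabled)  # False
print(net.l1.b.update_rule.enabled)  # False

net.l1.enable_update()               # turn updates back on for all of l1's parameters
print(net.l1.W.update_rule.enabled)  # True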

Change hyperparameters

Hyperparameters such as the learning rate can be changed by manipulating the attributes of hyperparam.

Example:

net.l1.W.update_rule.hyperparam.lr = 1.0
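
As far as I know, each parameter's hyperparam falls back to the optimizer's values until you override it, so the assignment above affects only that one parameter. A sketch (again assuming optimizer = optimizers.SGD(lr=0.1) and optimizer.setup(net)):

print(net.l2.W.update_rule.hyperparam.lr)  # 0.1: falls back to the optimizer's lr
net.l1.W.update_rule.hyperparam.lr = 1.0
print(net.l1.W.update_rule.hyperparam.lr)  # 1.0: overridden for this parameter only
print(net.l2.W.update_rule.hyperparam.lr)  # still 0.1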

Add hook function

By calling update_rule.add_hook(), hook functions such as chainer.optimizer.WeightDecay can be set for each parameter.

Example:

net.l1.W.update_rule.add_hook(chainer.optimizer.WeightDecay(0.0001))
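
A hook added this way runs only for that one parameter. To apply weight decay to every parameter, add the hook to the optimizer instead (after setup()) with optimizer.add_hook(chainer.optimizer.WeightDecay(0.0001)).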

Trying it out

As an example, let's increase the learning rate of one parameter (l1.W) and stop updating another (l1.b).

# -*- coding: utf-8 -*-

import numpy as np

import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers


class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(2, 2)
            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h


net = MLP()
optimizer = optimizers.SGD(lr=0.1)

# Calling setup() creates an UpdateRule for each parameter
optimizer.setup(net)

net.l1.W.update_rule.hyperparam.lr = 10.0
net.l1.b.update_rule.enabled = False

x = np.asarray([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
t = np.asarray([[0], [1], [1], [0]], dtype=np.int32)

y = net(x)

print('before')
print('l1.W')
print(net.l1.W.data)
print('l1.b')
print(net.l1.b.data)
print('l2.W')
print(net.l2.W.data)
print('l2.b')
print(net.l2.b.data)

loss = F.sigmoid_cross_entropy(y, t)
net.cleargrads()
loss.backward()
optimizer.update()

print('after')
print('l1.W')
print(net.l1.W.data)
print('l1.b')
print(net.l1.b.data)
print('l2.W')
print(net.l2.W.data)
print('l2.b')
print(net.l2.b.data)

The execution result is as follows. You can see that l1.W changes much more than l2.W, and that l1.b does not change at all.

before
l1.W
[[ 0.0049778  -0.16282777]
 [-0.92988533  0.2546134 ]]
l1.b
[ 0.  0.]
l2.W
[[-0.45893994 -1.21258962]]
l2.b
[ 0.]
after
l1.W
[[ 0.53748596  0.01032409]
 [ 0.47708291  0.71210718]]
l1.b
[ 0.  0.]
l2.W
[[-0.45838338 -1.20276082]]
l2.b
[-0.01014706]
