Behavior of a Container in Keras when trainable = False

Introduction

In Keras you often want to freeze the weights of one part of a network and train only the other layers. This is a memo of what I found you need to be careful about when doing so.

Versions

Verification

Consider the following Model (figure: model_normal.png).

Suppose that sometimes you want to update the weights of the NormalContainer part of this model, and sometimes you do not.

Intuitively, setting the Container#trainable property to False should be enough, so let's check whether it works as intended.

Code

# coding: utf8

import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model



def all_weights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]


def random_fit(m):
    x1 = np.random.random(10).reshape((5, 2))
    y1 = np.random.random(5).reshape((5, 1))
    m.fit(x1, y1, verbose=False)

np.random.seed(100)

x = in_x = Input((2, ))

# Create 2 Containers that share the same weights
x = Dense(1)(x)
x = Dense(1)(x)
fc_all = Container(in_x, x, name="NormalContainer")
fc_all_not_trainable = Container(in_x, x, name="FixedContainer")

# Create 2 Models using the Containers
x = fc_all(in_x)
x = Dense(1)(x)
model_normal = Model(in_x, x)

x = fc_all_not_trainable(in_x)
x = Dense(1)(x)
model_fixed = Model(in_x, x)

# Set one Container trainable=False
fc_all_not_trainable.trainable = False  # Case1

# Compile
model_normal.compile(optimizer="sgd", loss="mse")
model_fixed.compile(optimizer="sgd", loss="mse")

# fc_all_not_trainable.trainable = False  # Case2

# Watch which weights are updated by model.fit
print("Initial Weights")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_normal)

print("after training Model-Normal")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_fixed)

print("after training Model-Fixed")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))


# plot_model(model_normal, "model_normal.png", show_shapes=True)

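This creates two Containers, fc_all and fc_all_not_trainable, which wrap the same layers and therefore share the same weights. The latter has trainable set to False. Two Models, model_normal and model_fixed, are then built on top of them, each with its own final Dense layer.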

The expected behavior is as follows.

                     Container weights   Other weights
model_normal#fit()   Change              Change
model_fixed#fit()    Do not change       Change

Execution result: Case1

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37869808], [0.0091063408]]

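As expected. Training model_normal updates the Container weights (the change shows up in both models because the weights are shared), while training model_fixed changes only its own final Dense layer.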
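The comparison can also be done programmatically instead of by eye. The following is a minimal sketch: the helper name changed_indices is hypothetical and not part of Keras, and the two snapshots are copied from the Model-Fixed lines of the Case1 output above.

```python
def changed_indices(before, after, tol=0.0):
    """Return the indices of weight arrays that differ between two snapshots.

    `before` and `after` are lists of flat lists, in the format returned by
    all_weights(). This is a hypothetical helper, not a Keras API.
    """
    return [
        i
        for i, (b, a) in enumerate(zip(before, after))
        if any(abs(x - y) > tol for x, y in zip(b, a))
    ]


# Model-Fixed weights from the Case1 output,
# before and after random_fit(model_fixed).
before_fit = [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491],
              [-0.0012259937], [0.37929809], [0.0]]
after_fit = [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491],
             [-0.0012259937], [0.37869808], [0.0091063408]]

# Indices 0-3 belong to the two Dense layers inside the Container,
# indices 4-5 to the final Dense layer of model_fixed.
print(changed_indices(before_fit, after_fit))  # prints [4, 5]
```

Only the kernel and bias of the final Dense layer are reported, which matches the expected behavior in the table.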

Note: trainable = False must be set before compile()

What happens if you set trainable = False after Model#compile() instead (at the line marked Case2 in the code above)?

Execution result: Case2

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [0.37869808], [0.0091063408]]

The results are the same up to "after training Model-Normal". After training Model-Fixed, however, the Container weights have changed as well.

When Model#compile() is called, it collects trainable_weights from all the Layers the Model contains. Setting trainable after that point therefore has no effect on training.
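The difference between Case1 and Case2 can be illustrated with a toy model of this "collect at compile time" behavior. This is only a sketch under that assumption: ToyLayer and ToyModel are hypothetical classes written for this memo, not Keras classes, and weights are represented by plain strings.

```python
class ToyLayer:
    """Hypothetical stand-in for a layer; not a Keras class."""

    def __init__(self, name):
        self.name = name
        self.trainable = True

    @property
    def trainable_weights(self):
        # A frozen layer exposes no weights to the optimizer.
        if not self.trainable:
            return []
        return [self.name + "/kernel", self.name + "/bias"]


class ToyModel:
    """Hypothetical stand-in for a model that snapshots weights on compile()."""

    def __init__(self, layers):
        self.layers = layers
        self._collected = None

    def compile(self):
        # The list of weights to train is fixed here, at compile time.
        self._collected = [w for layer in self.layers
                           for w in layer.trainable_weights]

    def weights_updated_by_fit(self):
        # Training only ever touches the snapshot taken by compile().
        return list(self._collected)


container, head = ToyLayer("container"), ToyLayer("head")

# Case1: freeze first, then compile.
container.trainable = False
case1 = ToyModel([container, head])
case1.compile()
print(case1.weights_updated_by_fit())  # only the head weights

# Case2: compile first, then freeze.
container.trainable = True
case2 = ToyModel([container, head])
case2.compile()
container.trainable = False
print(case2.weights_updated_by_fit())  # container weights are still included
```

In the toy Case2, the flag is flipped after the snapshot was taken, so the container weights stay in the list that training updates, just as in the Case2 output above.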

Another point is that it is not necessary to set trainable on every layer inside the Container. From the Model's point of view, the Container is a single layer. The Model reads Container#trainable_weights, which returns nothing when Container#trainable is False (see the corresponding code at /keras/engine/topology.py#L1891), so none of the weights of the Layers inside the Container are updated. It is not entirely clear whether this is specified behavior or just how the current implementation happens to work, but I think it is probably intentional.
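A simplified sketch of that property may make this clearer. ToyDense and ToyContainer are hypothetical classes that only mimic the behavior described above; they are not the actual Keras source, and weights are represented by plain strings.

```python
class ToyDense:
    """Hypothetical stand-in for a layer that owns a kernel and a bias."""

    def __init__(self, name):
        self.name = name
        self.trainable = True

    @property
    def trainable_weights(self):
        if not self.trainable:
            return []
        return [self.name + "/kernel", self.name + "/bias"]


class ToyContainer:
    """Hypothetical stand-in for Container: a layer that wraps other layers."""

    def __init__(self, layers):
        self.layers = layers
        self.trainable = True

    @property
    def trainable_weights(self):
        # The container-level flag is checked first, so the inner layers'
        # own trainable flags are never consulted when it is False.
        if not self.trainable:
            return []
        return [w for layer in self.layers for w in layer.trainable_weights]


fc = ToyContainer([ToyDense("dense_1"), ToyDense("dense_2")])
print(fc.trainable_weights)  # all four inner weights

fc.trainable = False
print(fc.trainable_weights)  # empty list
print(all(layer.trainable for layer in fc.layers))  # inner flags are untouched
```

Because the outer flag short-circuits the lookup, one assignment on the Container is enough to hide every inner weight from the Model.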

Conclusion

That clears up what had been nagging at me.
