Comportement lorsque Container Trainable = False dans Keras

introduction

Souvent, vous souhaitez fixer le poids de votre réseau avec Keras et n'apprendre qu'une autre couche. C'est une note sur laquelle j'ai cherché à savoir à quoi faire attention à ce moment-là.

Versions

Vérification

Considérez le modèle suivant. model_normal.png

Supposons que vous vouliez "mettre à jour" le poids de la partie NormalContainer ici, et que parfois vous ne vouliez pas le mettre à jour.

Intuitivement, il semble bon de définir False sur la propriété Container # trainable, mais je vais essayer de voir si cela fonctionne comme prévu.

code

# coding: utf8

import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model



def all_weights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]


def random_fit(m):
    x1 = np.random.random(10).reshape((5, 2))
    y1 = np.random.random(5).reshape((5, 1))
    m.fit(x1, y1, verbose=False)

np.random.seed(100)

x = in_x = Input((2, ))

# Create 2 Containers shared same wights
x = Dense(1)(x)
x = Dense(1)(x)
fc_all = Container(in_x, x, name="NormalContainer")
fc_all_not_trainable = Container(in_x, x, name="FixedContainer")

# Create 2 Models using the Containers
x = fc_all(in_x)
x = Dense(1)(x)
model_normal = Model(in_x, x)

x = fc_all_not_trainable(in_x)
x = Dense(1)(x)
model_fixed = Model(in_x, x)

# Set one Container trainable=False
fc_all_not_trainable.trainable = False  # Case1

# Compile
model_normal.compile(optimizer="sgd", loss="mse")
model_fixed.compile(optimizer="sgd", loss="mse")

# fc_all_not_trainable.trainable = False  # Case2

# Watch which weights are updated by model.fit
print("Initial Weights")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_normal)

print("after training Model-Normal")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_fixed)

print("after training Model-Fixed")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))


# plot_model(model_normal, "model_normal.png ", show_shapes=True)

Créez deux Containers, fc_all et fc_all_not_trainable. Ce dernier laisse «formable» à False. Créez un Model appelé model_normal et model_fixed en l'utilisant.

Le comportement attendu est

C'est.

Poids du conteneur Autre poids
model_normal#fit() Changement Changement
model_fixed#fit() Ça ne change pas Changement

Résultat d'exécution: Cas 1

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37869808], [0.0091063408]]

Comme prévu.

Remarque: trainable = False doit être défini avant compile ()

Que faire si vous définissez trainable = False après Model # compile () (où le cas 2 est) dans le code ci-dessus?

Résultat d'exécution: Cas2

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [0.37869808], [0.0091063408]]

Idem jusqu'à ʻafter training Model-Normal, Lorsque ʻafter training Model-Fixed, le poids de Container change également.

Model # compile () fonctionne pour récupérer trainable_weights de tous les calques contenus lorsqu'il est appelé. Par conséquent, si vous ne définissez pas «entraînable» à ce stade, cela n'aura aucun sens.

Un autre point est qu'il n'est pas nécessaire de définir «formable» pour toutes les couches incluses dans le conteneur **. Container est une couche vue depuis Model. Model appelle Container # trainable_weights, mais ne retourne rien si Container # trainable est False (partie correspondante /keras/engine/topology.py#L1891)), donc tous les poids de couche contenus dans Container ne seront pas mis à jour. Il est un peu difficile de savoir s'il s'agit d'une spécification ou simplement de la mise en œuvre à ce stade, mais je pense que c'est probablement intentionnel.

à la fin

Le léger voile a été résolu.

Recommended Posts

Comportement lorsque Container Trainable = False dans Keras
Comportement lorsque plusieurs serveurs sont spécifiés dans les serveurs de noms de dnspython
Crypter faussement l'image lors de la compression
Comportement lors de la liste dans Python heapq
Vérifiez le comportement du destroyer en Python
Comportement lors du retour dans le bloc with
Changement de comportement de [Diagramme / Chronologie] dans Chorégraphe 2.5.5.5
J'étais en difficulté car le comportement du conteneur docker n'a pas changé
Différences de comportement de chaque langage LL lorsque l'index de la liste est ignoré
Placez Python3 dans le conteneur Docker d'Amazon Linux2
Comportement lors de l'enregistrement d'un objet datetime python dans MongoDB
Comportement de numpy.dot lors du passage d'un tableau 1d et d'un tableau 2d
Remarque lors de la mise de lxml du package python dans ubuntu 14.04