Que faire si vous obtenez une erreur de remplacement obligatoire `get_config` lorsque vous essayez de model.save avec Keras

introduction

Dans cet article, je voudrais vous présenter ce qu'il faut faire lorsque vous essayez de model.save (ou model.to_json) avec Keras et d'obtenir XX a des arguments dans __init__ et doit donc remplacer get_config``` . ..

Contexte et cause

Dans Keras, il existe de nombreuses couches prédéfinies telles que la couche Dense et la couche Conv, et celles-ci sont combinées pour concevoir un modèle de base. Mais plus avancé, vous devrez implémenter vos propres couches personnalisées et les ajouter à votre modèle. Par exemple, si vous souhaitez profiter du mécanisme publié dans le dernier article, il n'existe pas dans la couche prédéfinie Keras et vous devez le citer depuis Github ou l'implémenter vous-même. (Si vous souhaitez implémenter une couche personnalisée, consultez cet exemple officiel (https://github.com/keras-team/keras/blob/master/examples/antirectifier.py).) Alternativement, les débutants peuvent utiliser sans le savoir un modèle qui inclut une couche personnalisée lors de la création d'un script publié sur le noyau de kaggle. (J'ai moi-même fait face à cette erreur de cette façon.)

À propos, l'erreur XX (nom du calque personnalisé) a des arguments dans __init__ et doit donc remplacer get_config``` ne peut pas être traitée correctement pour le modèle comprenant cette ** couche personnalisée ** À ce moment-là, Keras m'a mis en colère, "Je ne connais pas une telle couche."

Solution

Cela peut être résolu en remplaçant get_config () '' dans la classe de couche personnalisée. Plus précisément, ** faites l'argument de init``` de la classe de couche personnalisée dans un dictionnaire, ajoutez-le à la configuration de la classe parente et retournez **, etc. get_config () Définir Cela signifie que l'argument de init '' est comme le document de conception de cette couche personnalisée, donc je l'ai fait arbitrairement ** J'enseigne explicitement à Keras comment fonctionne la couche personnalisée * * Équivalent à.

À propos, les modèles enregistrés de cette manière doivent également indiquer explicitement leur couche personnalisée dans l'argument custom_objects lors du chargement. La méthode est très simple, procédez comme suit:

load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})

Exemple concret

Prenons l'exemple du noyau public de Kaggle. [GLRec] ResNet50 ArcFace (TF2.2)

Dans ce script, la définition réelle du modèle est effectuée ci-dessous. Le modèle de backbone est prédéfini dans Keras dans ResNet 50. (Le poids peut être obtenu non seulement en utilisant celui enregistré localement comme cette fois, mais également par le package Keras.) Les couches de regroupement et de suppression sont également prédéfinies.

En cela, vous pouvez voir que seule la couche de marge est instanciée de manière unique. Il s'agit du calque personnalisé pour ce modèle.

create_model.py



def create_model(input_shape,
                 n_classes,
                 dense_units=512,
                 dropout_rate=0.0,
                 scale=30,
                 margin=0.3):

    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        input_shape=input_shape,
        weights=('../input/imagenet-weights/' +
                 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
    )

    pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
    dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
    dense = tf.keras.layers.Dense(dense_units, name='head/dense')

    margin = ArcMarginProduct(
        n_classes=n_classes,
        s=scale,
        m=margin,
        name='head/arc_margin',
        dtype='float32')

    softmax = tf.keras.layers.Softmax(dtype='float32')

    image = tf.keras.layers.Input(input_shape, name='input/image')
    label = tf.keras.layers.Input((), name='input/label')

    x = backbone(image)
    x = pooling(x)
    x = dropout(x)
    x = dense(x)
    x = margin([x, label])
    x = softmax(x)
    return tf.keras.Model(
        inputs=[image, label], outputs=x)

Vérifiez la classe de couche de marge ArcMarginProduct. Ensuite, vous pouvez voir qu'il s'agit d'un calque personnalisé qui hérite de tf.keras.layers.Layer. (À propos, la technologie mise en œuvre s'appelle ArcFace.)

Dans cette couche personnalisée définie par moi-même, quand je n'ai pas correctement remplacé get_config () '' ``, j'ai rencontré l'erreur au début lorsque j'ai fait model.save.

Dans ce noyau, get_config () '' n'est pas défini dans la classe, donc si vous essayez de sauvegarder tel quel, une erreur se produira.

custom_layer.py


class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

Par conséquent, vous devez apporter les modifications suivantes.

Plus précisément, il remplace get_config () '' et renvoie l'argument `` init` '' et la configuration de la classe parente.

new_custom_layer.py



class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m


###Commencer le code ajouté
    def get_config(self):
        config = {
            "n_classes" : self.n_classes,
            "s" : self.s,
            "m" : self.m,
            "easy_margin" : self.easy_margin,
            "ls_eps" : self.ls_eps
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

###  End       
        
    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

Et le modèle doit être chargé comme suit.

load_model.py



loaded_model =keras.models.load_model("path_to_model", custom_objects = {"ArcMarginProduct": ArcMarginProduct})

référence

Comment créer un calque personnalisé avec Keras Sérialiser les calques personnalisés avec des keras NotImplementedError: Layers with arguments in __init__ must override get_config

Recommended Posts

Que faire si vous obtenez une erreur de remplacement obligatoire `get_config` lorsque vous essayez de model.save avec Keras
Que faire si vous obtenez une erreur en essayant d'envoyer un message dans tasks.loop () immédiatement après le démarrage
Que faire si vous obtenez une erreur lors du chargement de mnist
Que faire si vous obtenez une erreur "Aucune version trouvée" sur pipenv
Que faire si vous obtenez une erreur de mémoire lors de la conversion de PySparkDataFrame en PandasDataFrame
Que faire si vous obtenez une erreur lors de l'importation de matplotlib en Python (Mac)
Que faire si vous obtenez une erreur Impossible de récupérer le lien métallique pour le référentiel avec yum
Que faire si vous obtenez une erreur lors de l'exécution de "certbot renouveler" dans l'environnement CakePHP
Que faire si vous obtenez une erreur non définie lorsque vous essayez d'utiliser pip avec pyenv
Que faire si vous obtenez l'erreur RuntimeError: Python n'est pas installé en tant que framework lorsque vous essayez d'utiliser matplitlib et pylab dans Python 3.3
Que faire si vous obtenez une erreur lors de l'installation de python avec pyenv
[Python] Choses à vérifier lorsqu'une erreur de décodage Unicode apparaît dans Django
Que faire si Combinaisons devient «couverture inconnue»
Que faire si une erreur 0xC0000005 se produit dans tf.train.start_queue_runners ()
Que faire si vous obtenez une erreur OpenSSL lors de l'installation de Python 2 avec pyenv
Que faire si vous obtenez "(35, 'Erreur de connexion SSL')" dans pycurl (l'un d'entre eux)
Que faire si vous obtenez une erreur d'importation lors de l'importation de matplotlib avec Jupyter
Que faire si vous obtenez l'erreur ʻERR_FEATURE_UNAVAILABLE_ON_PLATFORM` lors de l'utilisation de ts-node-dev sous Linux
Que faire si vous obtenez une erreur de décodage Unicode avec l'installation de pip
Que faire lorsque swagger-codegen est terminé avec python et Erreur d'importation: aucun module nommé n'apparaît
Que faire si vous recevez une erreur d'appel avec trop d'arguments d'entrée à faire et retourner dans un test de golang
Que faire si vous vous perdez dans la référence de fichier avec FileNotFoundError
Que faire si vous vous fâchez avec TensorFlow v2 sans l'attribut "app"
Que faire si vous obtenez une erreur indiquant que le compilateur C ne peut pas créer d'exécutables dans configure
Que faire lorsque TypeError se produit au minimum et au maximum de numpy
Que faire si vous obtenez `locale.Error: unsupported locale setting` lors de l'obtention de la date du jour en Python
Que faire lorsque vous vous fâchez avec "Value Error: unknown local: UTF-8" dans python manage.py syncdb
Que faire si une erreur de lien symbolique se produit dans l'importation cv lors de la tentative d'installation d'OpenCV en Python
Solution de contournement si vous obtenez une erreur lors de la tentative d'installation de PySide avec pip
Que faire si une erreur de codage Unicode se produit dans Sublime Text Python
Que faire si vous obtenez «Python non configuré». Utilisation de PyDev dans Eclipse
Que faire si une erreur de version se produit dans le pilote Selenium Chrome
Que faire si une erreur de décodage Unicode se produit dans pip
Que faire si on vous dit «Erreur d'importation: impossible d'importer le nom'HTTPSHandler '» lors de la création d'un environnement virtuel à l'aide de virtualenv
Que faire lorsqu'une erreur "service inconnu" est renvoyée par le serveur gRPC
Que faire quand "Aucun noyau pour le langage python trouvé" apparaît dans Hydrogen
Que faire si vous exécutez python sur IntelliJ et quittez avec une erreur
Que faire si pip donne une DistributionError dans Homebrew
Que faire lorsqu'une erreur de suppression se produit lors de la mise à jour de conda
Que faire si vous ne pouvez pas vous connecter en tant que root
Que faire si vous obtenez l'erreur "Erreur: opencv3: Ne prend pas en charge la construction des wrappers Python 2 et 3" lors de l'installation d'openCV 3
Que faire lorsque vous obtenez une erreur indiquant «Échec temporaire de la résolution du nom» sous Linux
Que faire si vous obtenez une erreur lors du vagabondage lorsque vous activez public_network ou private_network sur Vagrant + Arch Linux → Installer netctl
Que faire si vous ne pouvez pas utiliser la poubelle dans Lubuntu 18.04.
Que faire lorsque vous obtenez "Je ne peux pas voir le site !!!!"
Que faire si vous vous mettez en colère si vous n'avez pas libxml / xmlversion.h lors de l'installation de lxml sur CentOS
Que faire lorsque pip --user renvoie une erreur dans un environnement virtuel créé avec pyenv
Que faire s'il y a un décimal dans python json .dumps
Que faire si PDO n'est pas trouvé dans Laravel ou CakePHP
Lors d'une erreur de programmation: (1146, "La table '<nom de la table>' n'existe pas") se produit dans Django
Que faire si vous ne pouvez pas utiliser la recherche de grille de sklearn en Python
Que faire si vous êtes bloqué pendant l'installation d'Anaconda sur Linux
Que faire si une erreur se produit lors de l'importation de numpy avec VScode
Que faire si vous ne pouvez pas installer avec pip dans l'environnement babun
Que faire si vous obtenez Impossible de récupérer l'URL 443 avec pip
Que se passe-t-il si vous "importez A, B comme C" en Python?
[OSX] [pyenv] Que faire lorsqu'une erreur SSL se produit dans pip
Que faire lorsqu'un message d'avertissement est affiché dans la liste des pip