I recently worked out, by trial and error, how to implement a deep learning model in Keras that switches its processing depending on annotation (weighting) information supplied by the user, and this post summarizes what I learned. I had mostly been working in PyTorch lately, so the difference in how networks are described tripped me up. In Keras, complex networks are described with the Functional API. Reference: the Qiita article "keras functional API usage memo".
In the Functional API, a network is built by connecting the layers defined in keras.layers. To insert a layer of custom processing, as in this post, you have to implement it with Lambda. Below is a code example for a network that assumes an image recognition task.
Normally, a network takes only the original image as input. Here, in addition to the original image, it also receives a weighting map corresponding to the image and a flag (0 or 1) indicating whether to use that map, for a total of three inputs. The original image, the weighting map, and the flag all vary with the input data, so each is defined with keras.layers.Input. Because the flag is therefore a symbolic tensor rather than a plain Python value, you cannot simply branch on it with an if statement; the switch has to be expressed inside the graph.
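For a concrete picture, the three inputs might be declared as follows. This is only a sketch: the shapes (a 256x256 single-channel image, a weighting map matching the feature map it will later multiply, and one scalar flag per sample) are assumptions for illustration, not something fixed by the network code below.

from keras.layers import Input

# Shapes below are assumptions for illustration
x = Input(shape=(256, 256, 1), name='original_image')                # original image
user_weight_map = Input(shape=(16, 16, 1), name='user_weight_map')   # user's weighting map
user_weight_map_flg = Input(shape=(1,), name='user_weight_map_flg')  # 0 or 1 per sample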
from keras.models import Model
from keras.layers import Conv2D, Activation, BatchNormalization, GlobalAveragePooling2D, Dense, Input, Lambda, Add, Multiply
from keras.backend import switch as k_switch
from keras.backend import equal as k_equal
import yaml
def net(x, user_weight_map, user_weight_map_flg, feature_ch=16):
    """
    x: original image
    user_weight_map: weighting map given by the user
    user_weight_map_flg: flag (0 or 1) for using the user-given weighting map
    """
    # Apply convolution four times (each block chains on h, not on x)
    h = Conv2D(feature_ch, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    h = Conv2D(feature_ch*2, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    h = Conv2D(feature_ch*4, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    h = Conv2D(feature_ch*8, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    # ---------------------
    # Compute the weighting map (self-attention) in the branch network
    bh = Conv2D(feature_ch*4, 3, strides=1, padding='same')(h)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    bh = Conv2D(feature_ch*2, 3, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    # Weighting map predicted by the model, squashed to [0, 1] by the sigmoid
    model_weight = Conv2D(1, 3, strides=1, padding='same')(bh)
    model_weight = BatchNormalization()(model_weight)
    model_weight = Activation(activation='sigmoid', name='model_weight_output')(model_weight)
    # Auxiliary output head of the branch network
    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = GlobalAveragePooling2D()(bh)
    bh = Dense(1000)(bh)
    # ---------------------
    # Read the flag and switch between the weighting map computed by the
    # network and the weighting map created by the user
    weight_h = Lambda(lambda x: switch_weight_map(x), name='switch_weight_map')(
        [h, model_weight, user_weight_map, user_weight_map_flg])
    h = Add(name='weight_map_add')([h, weight_h])
    h = Conv2D(feature_ch*16, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    h = Conv2D(feature_ch*32, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    h = GlobalAveragePooling2D()(h)
    h = Dense(1000)(h)
    return h, bh, model_weight
def switch_weight_map(inputs):
    feature_map = inputs[0]
    model_weight_map = inputs[1]
    user_weight_map = inputs[2]
    user_weight_map_flg = inputs[3]
    # Weight the feature map with each candidate map
    model_weight = Multiply()([feature_map, model_weight_map])
    user_weight = Multiply()([feature_map, user_weight_map])
    # flag == 0 -> use the model's map, otherwise use the user's map
    weight_cond = k_equal(user_weight_map_flg, 0)
    weight_h = k_switch(weight_cond, model_weight, user_weight)
    return weight_h
# Save network architecture parameters
def save_network_param(save_path, feature_ch):
    param = {'base_feature_num': feature_ch}
    with open(save_path, 'w') as f:
        yaml.dump(param, f, default_flow_style=False)

# Load network architecture parameters
def load_network_param(load_path):
    with open(load_path) as f:
        param = yaml.safe_load(f)
    return param
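For completeness, here is one way the pieces above might be tied together into a Model. This is a sketch under the same assumed input shapes as the earlier Input declarations; the optimizer and loss are placeholders I chose for illustration, not part of the original code.

# Sketch: assemble the three-input model (shapes and loss are assumptions)
x = Input(shape=(256, 256, 1), name='original_image')
user_weight_map = Input(shape=(16, 16, 1), name='user_weight_map')
user_weight_map_flg = Input(shape=(1,), name='user_weight_map_flg')

h, bh, model_weight = net(x, user_weight_map, user_weight_map_flg, feature_ch=16)
model = Model(inputs=[x, user_weight_map, user_weight_map_flg],
              outputs=[h, bh, model_weight])
model.compile(optimizer='adam', loss='mean_squared_error')  # placeholder loss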
During training, if you try to save the whole model every epoch with the argument save_weights_only=False in the keras.callbacks.ModelCheckpoint() callback, you get an error along the lines of "can't pickle _thread.RLock objects". The same error appeared when I tried to export the model with model.to_json() or model.to_yaml(). The cause seems to be that, until the Lambda receives input data, it holds an unresolved input, so the model cannot be pickled. The workaround: save with save_weights_only=True in keras.callbacks.ModelCheckpoint(), prepare save_network_param() and load_network_param(), and, to use the trained model at prediction time, rebuild the network structure from the network code plus the exported yaml file and restore each layer's weights with model.load_weights().
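Concretely, that training/prediction workflow might look like the following sketch. The file names are placeholders I chose for illustration, and the Input tensors are the ones defined in the assembly sketch above.

from keras.callbacks import ModelCheckpoint

# Training: save only the weights, plus the parameters needed to rebuild the graph
checkpoint = ModelCheckpoint('weights_{epoch:03d}.h5', save_weights_only=True)
save_network_param('network_param.yaml', feature_ch=16)
# model.fit(..., callbacks=[checkpoint])

# Prediction: rebuild the same structure from code + yaml, then restore the weights
param = load_network_param('network_param.yaml')
h, bh, model_weight = net(x, user_weight_map, user_weight_map_flg,
                          feature_ch=param['base_feature_num'])
model = Model(inputs=[x, user_weight_map, user_weight_map_flg],
              outputs=[h, bh, model_weight])
model.load_weights('weights_001.h5')  # whichever epoch you want to restore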
The trick in the Lambda implementation was to pass the argument x as a list, [h, model_weight, user_weight_map, user_weight_map_flg]. If, as below, the Lambda takes only user_weight_map_flg as its argument x, then when Keras interprets the network structure it cannot determine how model_weight and user_weight connect to the other layers, and saving and loading fail.
weight_h = Lambda(lambda x: k_switch(k_equal(x, 0), model_weight, user_weight), name='switch_weight_map')(user_weight_map_flg)
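By contrast, the form used in the full code above routes every tensor the function touches through the Lambda's input list, so Keras can register them all as inbound nodes of the layer:

# Working form: every tensor used inside the Lambda is in the input list
weight_h = Lambda(lambda x: switch_weight_map(x), name='switch_weight_map')(
    [h, model_weight, user_weight_map, user_weight_map_flg])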
References:
https://stackoverflow.com/questions/52448652/attributeerror-nonetype-object-has-no-attribute-inbound-nodes-while-trying
https://stackoverflow.com/questions/44855603/typeerror-cant-pickle-thread-lock-objects-in-seq2seq
https://github.com/keras-team/keras/issues/8343
https://github.com/matterport/Mask_RCNN/issues/1126
https://stackoverflow.com/questions/53212672/read-only-mode-in-keras
https://stackoverflow.com/questions/47066635/checkpointing-keras-model-typeerror-cant-pickle-thread-lock-objects/55229794#55229794
https://blog.shikoan.com/lambda_arguments/
https://github.com/keras-team/keras/issues/6621
https://stackoverflow.com/questions/59635570/keras-backend-k-switch-for-loss-function-error