J'ai essayé le réglage fin de CNN avec Resnet

Réglage fin avec Python Resnet

** Cet article présente le réglage fin à l'aide de Resnet. ** **

Cette fois, afin de vérifier l'efficacité du réglage fin, il est normal Je voudrais comparer [Modèle créé avec CNN] et [Modèle affiné].

Environnement de développement

J'ai appris avec Google Colab, mais je vais le décrire comme possible même avec des hypothèses locales. python : 3.7.0 keras : 2.4.3 tensorflow : 2.2.0

Qu'est-ce que Resnet

Resnet est un réseau neuronal convolutif [appris] qui dépasse 1 million de feuilles dans une base de données appelée Imagenet. Et ce réseau, également appelé Resnet 50, comporte 50 couches et peut être classé en 1000 genoux de catégorie.

Qu'est-ce que le réglage fin?

Recyclage des pondérations de l'ensemble du modèle, avec les pondérations de réseau entraînées comme valeur initiale. Par conséquent, nous visons à faire un meilleur discriminateur en réapprenant après avoir utilisé le Resnet50 ci-dessus.

Image à utiliser

Cela me paraissait intéressant, j'ai donc essayé d'utiliser 3 types de cigarettes japonaises. ** Mobius: 338 feuilles ** ** Seven Star: 552 feuilles ** ** Winston: 436 feuilles **

Cette fois, il s'agit d'une vérification du réglage fin, donc l'image ci-dessus a été extraite d'Internet.

CNN normal

La configuration est la suivante. 1.png

L'apprentissage de cela a donné les résultats suivants. tobacco_dataset_reslut.jpg

C'est assez mauvais parce que nous n'avons effectué aucun prétraitement tel que le recadrage de l'image. Est-ce environ [60%] dans les données de test?

La source

normal_cnn.py


from PIL import Image
import numpy as np
import glob
import os
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers.convolutional import MaxPooling2D
from keras.layers import  Conv2D, Flatten, Dense, Dropout
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

root = "tobacco_dataset"
folder = os.listdir(root)
image_size = 224
dense_size = len(folder)
epochs = 30
batch_size = 16

X = []
Y = []
for index, name in enumerate(folder):
    dir = "./" + root + "/" + name
    print("dir : ", dir)
    files = glob.glob(dir + "/*")
    print("number : " + str(files.__len__()))
    for i, file in enumerate(files):
      try:
        image = Image.open(file)
        image = image.convert("RGB")
        image = image.resize((image_size, image_size))
        data = np.asarray(image)
        X.append(data)
        Y.append(index)
      except :
          print("read image error")

X = np.array(X)
Y = np.array(Y)
X = X.astype('float32')
X = X / 255.0

Y = np_utils.to_categorical(Y, dense_size)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.15)

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(image_size, image_size, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(dense_size, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

result = model.fit(X_train, y_train, validation_split=0.15, epochs=epochs, batch_size=batch_size)

x = range(epochs)
plt.title('Model accuracy')
plt.plot(x, result.history['accuracy'], label='accuracy')
plt.plot(x, result.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), borderaxespad=0, ncol=2)

name = 'tobacco_dataset_reslut.jpg'
plt.savefig(name, bbox_inches='tight')
plt.close()


Réglage fin avec Resnet

Comme la composition est assez profonde et ne peut pas être collée sur Qiita, je décrirai l'URL https://github.com/daichimizuno/cnn_fine_tuning/blob/master/finetuning_layer.txt

En plus du Resnet d'origine, les fonctions d'activation Relu et dropout sont ajoutées de 0,5. À propos, les sources suivantes sont les sources lors de l'utilisation de Resnet. Il semble que Keras contient une bibliothèque pour Resnet.


ResNet50 = ResNet50(include_top=False, weights='imagenet',input_tensor=input_tensor)

L'apprentissage de cela a donné les résultats suivants. resnet_tobacco_dataset_reslut.jpg

Est-ce environ [85%] dans les données de test?

La source

resnet_fine_tuning.py


from PIL import Image
import numpy as np
import glob
import os
from keras.utils import np_utils
from keras.models import Sequential, Model
from keras.layers import Flatten, Dense,Input, Dropout
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from keras.applications.resnet50 import ResNet50
from keras import optimizers

root = "tobacco_dataset"
folder = os.listdir(root)
image_size = 224
dense_size = len(folder)
epochs = 30
batch_size = 16

X = []
Y = []
for index, name in enumerate(folder):
    dir = "./" + root + "/" + name
    print("dir : ", dir)
    files = glob.glob(dir + "/*")
    print("number : " + str(files.__len__()))
    for i, file in enumerate(files):
      try:
        image = Image.open(file)
        image = image.convert("RGB")
        image = image.resize((image_size, image_size))
        data = np.asarray(image)
        X.append(data)
        Y.append(index)
      except :
          print("read image error")

X = np.array(X)
Y = np.array(Y)
X = X.astype('float32')
X = X / 255.0

Y = np_utils.to_categorical(Y, dense_size)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.15)

input_tensor = Input(shape=(image_size, image_size, 3))
ResNet50 = ResNet50(include_top=False, weights='imagenet',input_tensor=input_tensor)

top_model = Sequential()
top_model.add(Flatten(input_shape=ResNet50.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(dense_size, activation='softmax'))

top_model = Model(input=ResNet50.input, output=top_model(ResNet50.output))
top_model.compile(loss='categorical_crossentropy',optimizer=optimizers.SGD(lr=1e-3, momentum=0.9),metrics=['accuracy'])

top_model.summary()
result = top_model.fit(X_train, y_train, validation_split=0.15, epochs=epochs, batch_size=batch_size)

x = range(epochs)
plt.title('Model accuracy')
plt.plot(x, result.history['accuracy'], label='accuracy')
plt.plot(x, result.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), borderaxespad=0, ncol=2)

name = 'resnet_tobacco_dataset_reslut.jpg'
plt.savefig(name, bbox_inches='tight')
plt.close()


Résumé

En comparant CNN normal et le réglage fin, vous pouvez voir que le réglage fin s'est nettement amélioré à partir de l'époque 10 environ. Puisqu'il s'agit d'un ré-apprentissage à l'aide de Resnet, j'imagine qu'il se propagera dans la partie où le taux de réponse correcte de Resnet augmente et affecte la discrimination des cigarettes, mais une analyse plus détaillée est nécessaire. ..

Cependant, dans tous les cas, le résultat était considérablement plus élevé que celui d'un CNN normal en termes de jugement sur le tabac. Je pense que des résultats encore meilleurs seront obtenus si vous sélectionnez des images et effectuez un prétraitement, donc si vous le pouvez, essayez-le!

【Github】 https://github.com/daichimizuno/cnn_fine_tuning

Veuillez signaler toute erreur ou tout point peu clair. c'est tout

Recommended Posts

J'ai essayé le réglage fin de CNN avec Resnet
J'ai essayé fp-growth avec python
J'ai essayé de gratter avec Python
J'ai essayé Learning-to-Rank avec Elasticsearch!
J'ai essayé le clustering avec PyCaret
J'ai essayé gRPC avec Python
J'ai essayé de mettre en œuvre le chapeau de regroupement de Harry Potter avec CNN
J'ai essayé de résumer des phrases avec summpy
J'ai essayé l'apprentissage automatique avec liblinear
J'ai essayé webScraping avec python.
J'ai essayé de déplacer de la nourriture avec SinGAN
J'ai essayé d'implémenter DeepPose avec PyTorch
Divers réglages fins avec Mobilenet v2
J'ai essayé d'exécuter prolog avec python 3.8.2.
J'ai essayé la communication SMTP avec Python
J'ai essayé la génération de phrases avec GPT-2
J'ai essayé d'apprendre LightGBM avec Yellowbrick
J'ai essayé la reconnaissance faciale avec OpenCV
J'ai essayé d'envoyer un SMS avec Twilio
J'ai essayé linebot avec flacon (anaconda) + heroku
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai essayé l'analyse factorielle avec des données Titanic!
J'ai essayé d'apprendre avec le Titanic de Kaggle (kaggle②)
J'ai essayé le rendu non réaliste avec Python + opencv
J'ai essayé un langage fonctionnel avec Python
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé de jouer avec l'image avec Pillow
J'ai essayé de résoudre TSP avec QAOA
J'ai essayé la reconnaissance d'image simple avec Jupyter
J'ai essayé le traitement du langage naturel avec des transformateurs.
# J'ai essayé quelque chose comme Vlookup avec Python # 2
765 J'ai essayé d'identifier les trois familles professionnelles par CNN (avec Chainer 2.0.0)
J'ai essayé la reconnaissance manuscrite des caractères des runes avec scikit-learn
J'ai essayé de prédire l'année prochaine avec l'IA
J'ai essayé de "lisser" l'image avec Python + OpenCV
J'ai essayé des centaines de millions de SQLite avec python
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé la reconnaissance d'image de CIFAR-10 avec Keras-Learning-
J'ai essayé d'apprendre le fonctionnement logique avec TF Learn
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé de "différencier" l'image avec Python + OpenCV
J'ai essayé de gratter
J'ai essayé "License OCR" avec l'API Google Vision
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé PyQ
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé Flask avec des conteneurs distants de VS Code
J'ai essayé la décomposition matricielle non négative (NMF) avec TensorFlow
J'ai essayé L-Chika avec Razpai 4 (édition Python)
J'ai essayé le traitement de boucle de type Python avec BigQuery Scripting
J'ai essayé de jouer en connectant PartiQL et MongoDB
J'ai essayé d'analyser les principaux composants avec les données du Titanic!
J'ai essayé la différenciation jacobienne et partielle avec python
J'ai essayé d'obtenir des données CloudWatch avec Python
J'ai essayé d'utiliser mecab avec python2.7, ruby2.3, php7