L'histoire d'un débutant en apprentissage profond essayant de classer les guitares avec CNN

Aperçu

Certaines personnes l'ont déjà essayé sur Qiita, mais cela sert également de leur propre étude. J'ai essayé de classer les images de guitare en utilisant CNN (ResNet), donc je l'ai essayé dans le processus, Voici quelques éléments qui peuvent être utiles. (Comme ce n'est pas résumé, c'est un peu sale, mais je posterai aussi le code)

table des matières

À propos de la méthode de classification spécifique

Une image de guitare est obtenue par grattage et prétraitée pour gonfler l'image. En affinant ResNet, qui est une méthode de CNN, en utilisant des images gonflées Je vais essayer de faire de l'apprentissage automatique sans dépenser trop de frais d'apprentissage.

À propos des étiquettes

J'ai choisi les modèles suivants, qui semblent relativement faciles à collecter des images.

À propos du prétraitement

Le premier est de collecter des images. Cette fois, je l'ai récupéré en utilisant iCrawler. Généralement, la plupart d'entre eux sont collectés à partir de la recherche d'images Google, mais à partir du 12 mars 2020, en raison de changements dans les spécifications du côté Google. Cette fois, j'ai collecté des images de Bing car l'outil semble être en panne.

crawling.py


import os

from icrawler.builtin import BingImageCrawler

searching_words = [
                    "Fender Stratocaster",
                    "Fender Telecaster",
                    "Fender Jazzmaster",
                    "Fender Jaguar",
                    "Fender Mustang",
                    "Gibson LesPaul",
                    "Gibson SG",
                    "Gibson FlyingV",
                    "Gibson ES-335",
                    "Acoustic guitar"
                ]
if __name__ == "__main__":
    for word in searching_words:
        if not os.path.isdir('./searched_image/' + word):
            os.makedirs('./searched_image/' + word)
        bing_crawler = BingImageCrawler(storage={ 'root_dir': './searched_image/' + word })
        bing_crawler.crawl(keyword=word, max_num=1000)

Après la collecte, j'ai omis manuellement les images qui sont peu susceptibles d'être utilisées (celles qui ne montrent pas tout le corps de la guitare, celles qui contiennent des lettres, celles qui ont des reflets comme les mains, etc.). En conséquence, nous avons pu collecter environ 100 à 160 images pour chaque étiquette. (J'ai spécifié max_num = 1000 dans la méthode d'exploration, mais il n'a collecté qu'environ 400 feuilles.)

Ensuite, nous pré-traiterons les images collectées. Cette fois, l'image a été tournée de 45 ° et inversée. Le résultat a donc été multiplié par 16 pour atteindre environ 1 600 à 2 000 images pour chaque étiquette.

image_preprocessing.py


import os
import glob

from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split 

#La taille de l'image à compresser
image_size = 224
#Nombre de données d'entraînement
traindata = 1000
#Nombre de données de test
testdata = 300

#Nom du dossier d'entrée
src_dir = './searched_image'
#Nom du dossier de sortie
dst_dir = './input_guitar_data'

#Nom de l'étiquette à identifier
labels = [
                    "Fender Stratocaster",
                    "Fender Telecaster",
                    "Fender Jazzmaster",
                    "Fender Jaguar",
                    "Fender Mustang",
                    "Gibson LesPaul",
                    "Gibson SG",
                    "Gibson FlyingV",
                    "Gibson ES-335",
                    "Acoustic guitar"
                ]
#Chargement des images
for index, label in enumerate(labels):
    files =glob.glob("{}/{}/all/*.jpg ".format(src_dir, label))
        
    #Données converties d'image
    X = []
    #étiquette
    Y = []

    for file in files:
        #Ouvrir l'image
        img = Image.open(file)
        img = img.convert("RGB")
        
        #===================#Convertir en carré#===================#
        width, height = img.size
        #S'il est verticalement long, développez-le horizontalement
        if width < height:
            result = Image.new(img.mode,(height, height),(255, 255, 255))
            result.paste(img, ((height - width) // 2, 0))
        #S'il est horizontalement long, développez-le verticalement
        elif width > height:
            result = Image.new(img.mode,(width, width),(255, 255, 255))
            result.paste(img, (0, (width - height) // 2))
        else:
            result = img

        #Aligner la taille de l'image sur 224x224
        result.resize((image_size, image_size))

        data = np.asarray(result)
        X.append(data)
        Y.append(index)

        #===================#Données gonflées#===================#
        for angle in range(0, 360, 45):
            #rotation
            img_r = result.rotate(angle)
            data = np.asarray(img_r)
            X.append(data)
            Y.append(index)

            #Inverser
            img_t = img_r.transpose(Image.FLIP_LEFT_RIGHT)
            data = np.asarray(img_t)
            X.append(data)
            Y.append(index)
    
    #Normalisation(0~255->0~1)
    X = np.array(X,dtype='float32') / 255.0
    Y = np.array(Y)


    #Fractionner les données pour la vérification des intersections
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=testdata, train_size=traindata)
    xy = (X_train, X_test, y_train, y_test)
    np.save("{}/{}_{}.npy".format(dst_dir, label, index), xy)

Enregistrez les résultats prétraités dans un fichier npy pour chaque étiquette.

À propos de la méthode d'apprentissage

Cette fois, je vais essayer d'apprendre à utiliser ResNet, qui est une méthode typique de CNN. Étant donné que le PC que je possède n'a pas de GPU NVIDIA, si j'essaie de l'entraîner tel quel, cela prendra énormément de temps car il ne sera calculé que par le processeur, alors exécutons et apprenons le code suivant dans l'environnement GPGPU à l'aide de Google Colab. J'ai fait. (Comment utiliser Colab, comment télécharger des fichiers, etc. sont omis)

import gc

import keras
from keras.applications.resnet50 import ResNet50
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense, Input
from keras.callbacks import EarlyStopping 
from keras.utils import np_utils
from keras import optimizers

from sklearn.metrics import confusion_matrix

import numpy as np
import matplotlib.pyplot as plt

#Définition de l'étiquette de classe
classes = [
                    "Fender Stratocaster",
                    "Fender Telecaster",
                    "Fender Jazzmaster",
                    "Fender Jaguar",
                    "Fender Mustang",
                    "Gibson LesPaul",
                    "Gibson SG",
                    "Gibson FlyingV",
                    "Gibson ES-335",
                    "Acoustic guitar"
                ]
num_classes = len(classes)

#Taille de l'image à charger
ScaleTo = 224

#Définition de la fonction principale
def main():
    #Lecture des données d'entraînement
    src_dir = '/content/drive/My Drive/Apprentissage automatique/input_guitar_data'

    train_Xs = []
    test_Xs = []
    train_ys = []
    test_ys = []

    for index, class_name in enumerate(classes):
        file = "{}/{}_{}.npy".format(src_dir, class_name, index)
        #Apportez un fichier d'apprentissage séparé
        train_X, test_X, train_y, test_y = np.load(file, allow_pickle=True)

        #Combinez les données en une seule
        train_Xs.append(train_X)
        test_Xs.append(test_X)
        train_ys.append(train_y)
        test_ys.append(test_y)

    #Combinez les données combinées
    X_train = np.concatenate(train_Xs, 0)
    X_test = np.concatenate(test_Xs, 0)
    y_train = np.concatenate(train_ys, 0)
    y_test = np.concatenate(test_ys, 0)

    #Étiquette
    y_train = np_utils.to_categorical(y_train, num_classes)
    y_test = np_utils.to_categorical(y_test, num_classes)


    #Génération de modèle d'apprentissage automatique
    model, history = model_train(X_train, y_train, X_test, y_test)
    model_eval(model, X_test, y_test)
    #Afficher l'historique d'apprentissage
    model_visualization(history)

def model_train(X_train, y_train, X_test, y_test):
    #Charge ResNet 50. Inclure car aucune couche entièrement connectée n'est requise_top=False
    input_tensor = Input(shape=(ScaleTo, ScaleTo, 3))
    resnet50 = ResNet50(include_top=False, weights='imagenet', input_tensor=input_tensor)

    #Créer une couche entièrement connectée
    top_model = Sequential()
    top_model.add(Flatten(input_shape=resnet50.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(num_classes, activation='softmax'))

    #Créez un modèle en combinant ResNet50 et une couche entièrement connectée
    resnet50_model = Model(input=resnet50.input, output=top_model(resnet50.output))

    """
    #Correction de certains poids de ResNet50
    for layer in resnet50_model.layers[:100]:
        layer.trainable = False
    """

    #Spécifier la classification multi-classes
    resnet50_model.compile(loss='categorical_crossentropy',
            optimizer=optimizers.SGD(lr=1e-3, momentum=0.9),
            metrics=['accuracy'])
    resnet50_model.summary()

    #Exécution de l'apprentissage
    early_stopping = EarlyStopping(monitor='val_loss', patience=0, verbose=1) 
    history = resnet50_model.fit(X_train, y_train,
                        batch_size=75,
                        epochs=25, validation_data=(X_test, y_test),
                        callbacks=[early_stopping])
    #Enregistrer le modèle
    resnet50_model.save("/content/drive/My Drive/Apprentissage automatique/guitar_cnn_resnet50.h5")
    
    return resnet50_model, history

def model_eval(model, X_test, y_test):
    scores = model.evaluate(X_test, y_test, verbose=1)
    print("test Loss", scores[0])
    print("test Accuracy", scores[1])
    #Calcul de la matrice de confusion
    predict_classes = model.predict(X_test)
    predict_classes = np.argmax(predict_classes, 1)
    true_classes = np.argmax(y_test, 1)
    print(predict_classes)
    print(true_classes)
    cmx = confusion_matrix(true_classes, predict_classes)
    print(cmx)
    #Effacer le modèle une fois l'inférence terminée
    del model
    keras.backend.clear_session() #← C'est
    gc.collect()

def model_visualization(history):
    #Affichage graphique de la valeur de perte
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

    #Affichage graphique du taux de réponse correct
    plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()
    
if __name__ == "__main__":
    main()

Cette fois, le résultat de val acc etc. était meilleur si le poids n'était pas fixe, donc le poids de chaque couche est également appris à nouveau. Dans le code, 100 époques sont entraînées, mais en réalité, Early Stopping a mis fin à l'apprentissage à la 5ème époque.

À propos des résultats d'apprentissage

Le résultat est le suivant.

test Loss 0.09369107168481061
test Accuracy 0.9744

Je vais également publier une matrice de confusion.

[[199   0   1   0   0   0   0   0   0   0]
 [  0 200   0   0   0   0   0   0   0   0]
 [  2   5 191   2   0   0   0   0   0   0]
 [  1   0  11 180   6   0   2   0   0   0]
 [  0   2   0   0 198   0   0   0   0   0]
 [  0   0   0   0   0 288   4   0   6   2]
 [  0   2   0   0   0   0 296   0   2   0]
 [  0   0   0   0   0   0   0 300   0   0]
 [  0   0   0   0   0   0   0   0 300   0]
 [  0   0   0   0   0   0   0   1   0 299]]

ダウンロード2.png ダウンロード.png

À la fin d'une époque, vous pouvez voir que l'apprentissage a considérablement progressé.

Essayez et jouez

J'essaierai l'inférence basée sur le modèle enregistré. Cette fois, j'ai essayé d'en faire une application web très rudimentaire en utilisant Flask que j'ai touché pour la première fois.

graphing.py


import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

def to_graph(image, labels, predicted):
    #=======#Tracer et enregistrer#=======#
    fig = plt.figure(figsize=(10.24, 5.12))
    fig.subplots_adjust(left=0.2)

    #=======#Écrire un graphique à barres horizontales#=======#
    ax1 = fig.add_subplot(1,2,1)
    ax1.barh(labels, predicted, color='c', align="center")
    ax1.set_yticks(labels)#étiquette de l'axe y
    ax1.set_xticks([])#Supprimer l'étiquette de l'axe x

    #Écrire des nombres dans des graphiques à barres
    for interval, value in zip(range(0,len(labels)), predicted):
        ax1.text(0.02, interval, value, ha='left', va='center')

    #=======#Insérez l'image identifiée#=======#
    ax2 = fig.add_subplot(1,2,2)
    ax2.imshow(image)
    ax2.axis('off')

    return fig

def expand_to_square(input_file):
    """Convertir une image rectangulaire en carré
    input_file:Nom de fichier à convertir
Valeur de retour:Image convertie
    """
    img = Image.open(input_file)
    img = img.convert("RGB")
    
    width, height = img.size
    #S'il est verticalement long, développez-le horizontalement
    if width < height:
        result = Image.new(img.mode,(height, height),(255, 255, 255))
        result.paste(img, ((height - width) // 2, 0))
    #S'il est horizontalement long, développez-le verticalement
    elif width > height:
        result = Image.new(img.mode,(width, width),(255, 255, 255))
        result.paste(img, (0, (width - height) // 2))
    else:
        result = img
    
    return result 

predict_file.py


predict_file.py
import io
import gc

from flask import Flask, request, redirect, url_for
from flask import flash, render_template, make_response

from keras.models import Sequential, load_model
from keras.applications.resnet50 import decode_predictions
import keras

import numpy as np
from PIL import Image
from matplotlib.backends.backend_agg import FigureCanvasAgg

import graphing

classes = [
            "Fender Stratocaster",
            "Fender Telecaster",
            "Fender Jazzmaster",
            "Fender Jaguar",
            "Fender Mustang",
            "Gibson LesPaul",
            "Gibson SG",
            "Gibson FlyingV",
            "Gibson ES-335",
            "Acoustic guitar"
            ]
num_classes = len(classes)
image_size = 224
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'gif'])


app = Flask(__name__)

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.',1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/', methods=['GET', 'POST'])
def upload_file():
    if request.method == 'POST':
        if 'file' not in request.files:
            flash('Pas de fichier')
            return redirect(request.url)
        file = request.files['file']

        if file.filename == '':
            flash('Pas de fichier')
            return redirect(request.url)

        if file and allowed_file(file.filename):
            virtual_output = io.BytesIO()
            file.save(virtual_output)
            filepath = virtual_output

            model = load_model('./cnn_model/guitar_cnn_resnet50.h5')

            #Convertir l'image en carré
            image = graphing.expand_to_square(filepath)
            image = image.convert('RGB')
            #Aligner la taille de l'image sur 224x224
            image = image.resize((image_size, image_size))
            #Passer de l'image au tableau numpy et effectuer la normalisation
            data = np.asarray(image) / 255.0
            #Augmenter les dimensions du tableau(3D->4 dimensions)
            data = np.expand_dims(data, axis=0)
            #Faire des inférences à l'aide du modèle appris
            result = model.predict(data)[0]
            
            #Dessinez le résultat de l'inférence et l'image déduite dans un graphique
            fig = graphing.to_graph(image, classes, result)
            canvas = FigureCanvasAgg(fig)
            png_output = io.BytesIO()
            canvas.print_png(png_output)
            data = png_output.getvalue()

            response = make_response(data)
            response.headers['Content-Type'] = 'image/png'
            response.headers['Content-Length'] = len(data)

            #Effacer le modèle une fois l'inférence terminée
            del model
            keras.backend.clear_session()
            gc.collect()

            return response
    return '''
    <!doctype html>
    <html>
        <head>
            <meta charset="UTF-8">
            <title>Téléchargeons le fichier et jugeons</title>
        </head>
        <body>
            <h1>Téléchargez le fichier et jugez!</h1>
            <form method = post enctype = multipart/form-data>
                <p><input type=file name=file>
                <input type=submit value=Upload>
            </form>
        </body>
    </html>
    '''

À propos, si vous répétez plusieurs fois l'apprentissage et l'inférence sur Keras, les données semblent déborder dans la mémoire et il semble que vous deviez les effacer explicitement dans le code. (De même sur colab)

URL de référence ↓ Correction du problème d'augmentation de l'utilisation de la mémoire lors de l'apprentissage répété avec keras

En outre, je publierai le code source de l'application Web que j'ai réellement créée. ↓ Application Web de classification de guitare

Essayez et jouez

Je l'ai essayé avec mon propre instrument.

D'abord du maître du jazz ジャズマスター判定.png Il réagit également à Jaguar, qui présente de nombreuses similitudes. Cependant, s'il s'agit d'une autre image obtenue à partir d'un autre réseau, elle peut être jugée comme 99% Jazz Master, on ne peut donc pas dire que la précision de classification est mauvaise.

Puis Stratocaster ストラトキャスター判定.png Il était presque certainement déterminé à être une Stratocaster. Il semble qu'il n'y ait pas de problème particulier même si le contraste est légèrement sombre.

Alors, que se passe-t-il si vous les laissez déterminer quelle base ils n'ont pas formée? Je l'ai essayé avec mon type de jazz bass. ジャズベース判定.png Il n'est pas clair qu'il sera jugé comme Mustang, mais je crains que la probabilité de SG soit également élevée. Il semble que la partie tsuno ne soit pas similaire ...?

Résumé

Cette fois, en affinant ResNet, qui est une méthode de CNN, nous avons pu créer un classificateur qui est relativement facile à créer mais qui a une grande précision. Cependant, certains machine learning, tels que CNN, sont difficiles à expliquer pourquoi les résultats étaient ainsi. Donc, si j'ai le temps, je vais essayer des méthodes de visualisation telles que Grad-CAM à l'avenir.

c'est tout.

Recommended Posts

L'histoire d'un débutant en apprentissage profond essayant de classer les guitares avec CNN
Introduction à l'apprentissage profond ~ Expérience CNN ~
[Windows] L'histoire d'un débutant qui tombe sur le décor de PATH d'Anaconda.
L'histoire d'essayer de reconnecter le client
Une histoire à laquelle j'étais accro à essayer d'installer LightFM sur Amazon Linux
Une histoire sur un débutant essayant de configurer CentOS 8 (mémo de procédure)
J'ai essayé l'histoire courante de l'utilisation du Deep Learning pour prédire la moyenne Nikkei
L'histoire de la tentative de pousser SSH_AUTH_SOCK obsolète avec LD_PRELOAD à l'écran
Une histoire où un débutant est coincé en essayant de créer un environnement de plug-in vim 8.2 + python 3.8.2 + lua sur Ubuntu 18.04.4 LTS
Une histoire d'essayer d'installer uwsgi sur une instance EC2 et d'échouer
L'histoire de l'apprentissage profond avec TPU
De rien sur Ubuntu 18.04 à la configuration d'un environnement Deep Learning sur Tensor
Un mémorandum d'étude et de mise en œuvre du Deep Learning
Histoire d'essayer d'utiliser Tensorboard avec Pytorch
Créez un environnement python pour apprendre la théorie et la mise en œuvre de l'apprentissage profond
L'histoire d'un technicien de haut niveau essayant de prédire la survie du Titanic
Une histoire bloquée lors de la tentative de mise à niveau de la version Python avec GCE
Je ne trouve pas l'horloge tsc! ?? L'histoire d'essayer d'écrire un patch de noyau
Une histoire d'essais et d'erreurs essayant de créer un groupe d'utilisateurs dynamique dans Slack
Une histoire sur un débutant Python essayant d'obtenir des résultats de recherche Google à l'aide de l'API
Étapes pour créer rapidement un environnement d'apprentissage en profondeur sur Mac avec TensorFlow et OpenCV
Une histoire sur la tentative d'introduire Linter au milieu d'un projet Python (Flask)
Créer un ensemble de données d'images à utiliser pour la formation
Une histoire de prédiction du taux de change avec Deep Learning
Une histoire d'essayer pyenv, virtualenv et virtualenvwrapper
Learning Deep Forest, un nouveau dispositif d'apprentissage comparable à DNN
Une note d'essayer un simple tutoriel MCMC sur PyMC3
Une histoire sur un débutant Linux mettant Linux sur une tablette Windows
Deep learning 1 Pratique du deep learning
Essayez de créer un réseau de neurones / d'apprentissage en profondeur avec scratch
Record des leçons de l'enfer imposées aux étudiants débutants en Python
[Introduction à AWS] Mémorandum de création d'un serveur Web sur AWS
Comment enregistrer un package dans PyPI (à partir de septembre 2017)
Classifier les ensembles de données d'image CIFAR-10 à l'aide de divers modèles d'apprentissage en profondeur
Technologie qui prend en charge Jupyter: Tralets (histoire d'essayer de déchiffrer)
L'histoire d'un ingénieur directeur de 40 ans qui réussit "Deep Learning for ENGINEER"
Une histoire sur la tentative d'implémentation de variables privées en Python.
J'ai essayé d'écrire dans un modèle de langage profondément appris
Une histoire sur un débutant de GCP essayant de créer un serveur Micra avec GCE
Une histoire dans laquelle l'algorithme est arrivé à une conclusion ridicule en essayant de résoudre correctement le problème du voyageur de commerce
<Cours> Apprentissage en profondeur: Day2 CNN
Deep running 2 Réglage de l'apprentissage profond
Introduction au Deep Learning ~ Règles d'apprentissage ~
Apprentissage par renforcement profond 1 Introduction au renforcement de l'apprentissage
Apprentissage par renforcement profond 2 Mise en œuvre de l'apprentissage par renforcement
Introduction au Deep Learning ~ Rétropropagation ~
Une histoire coincée avec l'installation de la bibliothèque de machine learning JAX
Une histoire d'essayer d'automatiser un chot lorsque vous cuisinez vous-même
J'ai recherché une carte similaire de Hearthstone avec Deep Learning
Une histoire sur la tentative d'exécuter plusieurs versions de Python (édition Mac)
Une commande pour vérifier facilement la vitesse du réseau sur la console
Essayez de faire une stratégie de blackjack en renforçant l'apprentissage ((1) Implémentation du blackjack)
Une histoire sur le fait de vouloir penser à des personnages déformés dans GAE / P
L'histoire de l'échec de la mise à jour de "calendar.day_abbr" sur l'écran d'administration de django
Une histoire d'essayer d'améliorer le processus de test d'un système vieux de 20 ans écrit en C