Comment créer des données à mettre dans CNN (Chainer)

J'ai essayé de créer mon propre modèle en apprenant dans le hackathon, mais il a fallu du temps pour savoir comment créer les données à mettre dans CNN, alors mémo

from PIL import Image
import numpy as np
import glob
import random

def load_image():
    filepaths = glob.glob('data/*.png')

    datasets = []
    for filepath in filepaths:
        img = Image.open(filepath).convert('L')  #Charge avec oreiller.'L'Signifie une échelle de gris
        img = img.resize((32, 32)) #Redimensionné à 32x32x
        label = int(filepath.split('/')[-1].split('_')[0]) #étiquette(Entier supérieur ou égal à 0) (Dans mon cas, je mets souvent un nom d'étiquette au début du nom de fichier.)

        x = np.array(img, dtype=np.float32)
        x = x.reshape(1,32,32) # (Canal, hauteur, largeur)
        t = np.array(label, dtype=np.int32) 

        datasets.append((x,t)) #Appuyez sur x et t pour lister

    random.shuffle(datasets) #mélanger
    train = datasets[:1000] #Les mille premiers pour apprendre
    test = datasets[1000:1100] #Pour les tests du 1000 au 1100
    return train, test


def main(): #Ci-dessous, reportez-vous au cifer10 de l'exemple de chainer

    class_labels = 10
    train, test = load_image()
    
    model = L.Classifier(models.VGG.VGG(class_labels))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    optimizer = chainer.optimizers.MomentumSGD(args.learnrate)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)
    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(TestModeEvaluator(test_iter, model, device=args.gpu))

    # Reduce the learning rate by half every 25 epochs.
    trainer.extend(extensions.ExponentialShift('lr', 0.5),
                   trigger=(25, 'epoch'))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot at each epoch
    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()

Recommended Posts

Comment créer des données à mettre dans CNN (Chainer)
Essayez de mettre des données dans MongoDB
[Itertools.permutations] Comment créer une séquence en Python
Comment créer une grande quantité de données de test dans MySQL? ??
Comment créer des exemples de données CSV avec hypothèse
Comment créer un fichier JSON en Python
Comment exécuter CNN en notation système 1 avec Tensorflow 2
Comment lire les données de séries chronologiques dans PyTorch
Comment créer une API Rest dans Django
La première étape de l'analyse du journal (comment formater et mettre les données du journal dans Pandas)
Comment appliquer des marqueurs uniquement à des données spécifiques avec matplotlib
Comment créer rapidement des exemples de données pour un tableau pendant le codage
Comment créer un téléchargeur d'image avec Bottle (Python)
[Linux] Comment mettre votre IP dans une variable
Comment développer en Python
Comment gérer les trames de données
Comment créer et utiliser des bibliothèques statiques / dynamiques en langage C
Comment obtenir un aperçu de vos données dans Pandas
Comment créer une trame de données et jouer avec des éléments avec des pandas
Compagnon de science des données en python, comment spécifier des éléments dans les pandas
[Python] Comment faire PCA avec Python
Comment gérer une session dans SQLAlchemy
Comment mettre un lien symbolique
Comment lire les données de la sous-région e-Stat
Comment utiliser les classes dans Theano
Comment écrire sobrement avec des pandas
Comment collecter des images en Python
Comment mettre à jour Spyder dans Anaconda
Comment utiliser SQLite en Python
Comment gérer les données déséquilibrées
Comment créer un pont virtuel
Comment convertir 0,5 en 1056964608 en un seul coup
Comment créer / supprimer des liens symboliques
Malentendu sur la façon de connecter CNN
Comment refléter CSS dans Django
Comment tuer des processus en vrac
Comment utiliser Mysql avec python
Comment augmenter les données avec PyTorch
Comment envelopper C en Python
Comment utiliser ChemSpider en Python
Comment créer un Dockerfile (basique)
Comment utiliser PubChem avec Python
Comment exécuter du code TensorFlow 1.0 en 2.0
Comment gérer le japonais avec Python
Comment créer un fichier de configuration
Comment se connecter à Docker + NGINX
Comment collecter des données d'apprentissage automatique
Comment appeler PyTorch dans Julia
Pour utiliser python, mettez pyenv sur macOS avec PyCall
Comment visualiser où se produit une mauvaise classification dans la classification de l'analyse des données
Comment utiliser les colonnes calculées dans CASTable
Comment créer un clone depuis Github
[Introduction à Python] Comment utiliser la classe en Python?
Comment supprimer l'erreur d'affichage dans matplotlib
Comment collecter des données Twitter sans programmation
Comment définir dynamiquement des variables en Python