[PyTorch] Classification des images du CIFAR-10

Dans cet article, nous classerons les images de CIFAR-10 à l'aide de PyTorch. Suivez le Tutoriel officiel avec des commentaires. De plus, Python et l'apprentissage automatique sont des super débutants.

Qu'est-ce que CIFAR-10?

Un jeu de données d'image à 10 étiquettes largement utilisé dans le domaine de l'apprentissage automatique. airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck 10 étiquettes sont disponibles.

environnement

Installez PyTorch

Site officiel émettra une commande d'installation en fonction de chaque environnement. Puisque je suis un macOS, exécutez ce qui suit pour installer.

pip install torch torchvision

Mettre en œuvre CNN

Importez les bibliothèques requises

#Importer NumPy, Matplotlib, PyTorch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Télécharger CIFAR-10

#ToTensor: Image en échelle de gris (RVB 0)~255 à 0~Normaliser à la plage de 1), Normaliser: valeur Z (moyenne RVB et écart type à 0).Normaliser avec 5)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#Télécharger les données d'entraînement
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

#Télécharger les données de test
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)

Découvrez CIFAR-10

Vérifiez les données

#Ensemble de données d'entraînement: 50000 images RVB avec 32 pixels de hauteur et de largeur
print(trainset.data.shape)
(50000, 32, 32, 3)

#Jeu de données de test: 10000 images RVB avec 32 pixels de hauteur et de largeur
print(testset.data.shape)
(10000, 32, 32, 3)

#Consultez la liste des cours
print(trainset.classes)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

#Les classes sont souvent utilisées, alors gardez-les séparément
classes = trainset.classes

Dans le document officiel, ** avion a été redéfini comme avion ** et ** automobile a été redéfini comme voiture **. Pourquoi?

Afficher l'image

#Essayez d'afficher l'image téléchargée
def imshow(img):
    #Dénormaliser
    img = img / 2 + 0.5
    # torch.Du type Tensor à numpy.Convertir en type ndarray
    print(type(img)) # <class 'torch.Tensor'>
    npimg = img.numpy()
    print(type(npimg))    
    #Convertir la forme de (RVB, vertical, horizontal) à (vertical, horizontal, RVB)
    print(npimg.shape)
    npimg = np.transpose(npimg, (1, 2, 0))
    print(npimg.shape)
    #Afficher l'image
    plt.imshow(npimg)
    plt.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

Mettre en œuvre le réseau

#Mettre en œuvre CNN
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

Définissez chaque couche avec init et connectez-les avec forward.

Définir la fonction de perte / l'optimiseur

#Entropie croisée
criterion = nn.CrossEntropyLoss()
#Méthode de descente de gradient probabiliste
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Entraîner

#Entraîner
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        #Erreur de propagation de retour
        loss.backward()
        optimizer.step()
        train_loss = loss.item()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

[1,  2000] loss: 2.164
[1,  4000] loss: 1.863
[1,  6000] loss: 1.683
[1,  8000] loss: 1.603
[1, 10000] loss: 1.525
[1, 12000] loss: 1.470
[2,  2000] loss: 1.415
[2,  4000] loss: 1.369
[2,  6000] loss: 1.363
[2,  8000] loss: 1.333
[2, 10000] loss: 1.314
[2, 12000] loss: 1.317
Finished Training

La valeur moyenne de la perte pour chaque 2000 mini-lot est sortie dans le journal.

Enregistrer le modèle

#Enregistrer le modèle
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

Enregistrez le modèle dans le répertoire courant avec l'extension pth (PyTorch).

Essayez d'utiliser le modèle

#Charger les données de test et afficher l'image et l'étiquette correcte
dataiter = iter(testloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
#Chargez le modèle enregistré et prédisez
net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

image.png GroundTruth: truck cat airplane ship Predicted: truck horse airplane ship

Vous pouvez voir que les prédictions sont correctes sauf pour chat.

print(outputs)
value, predicted = torch.max(outputs, 1)
print(value)
print(predicted)

tensor([[ 0.7114, -2.2724,  0.1225,  0.9470,  2.1940,  1.8655, -2.6655,  4.1646,
         -1.1001, -1.6991],
        [-2.2453, -4.1017,  1.8291,  3.2079,  1.1242,  3.6712,  1.0010,  1.0489,
         -3.2010, -1.9476],
        [-3.0669, -3.8900,  0.9312,  3.5649,  2.7791,  1.5095,  2.1216,  1.5274,
         -4.3077, -2.2234],
        [-2.0948, -3.4640,  2.4833,  2.6210,  4.0590,  1.8350,  0.4924,  0.7212,
         -3.5043, -2.4212]], grad_fn=<AddmmBackward>)
tensor([4.1646, 3.6712, 3.5649, 4.0590], grad_fn=<MaxBackward0>)
tensor([7, 5, 3, 4])

** torch.max ** renvoie la valeur maximale des sorties.

Tester le modèle

correct = 0
total = 0
#Calculer sans se souvenir du gradient (sans apprentissage)
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 60 %

Vous pouvez voir que le taux de réponse correct pour 10000 données de test est de 60%.

Vous trouverez ci-dessous une note personnelle pour les débutants en Python. Pas bon ** (prédit == étiquettes) .sum (). Item () ** Je ne savais pas comment écrire ceci, donc je vais me déconnecter et vérifier.

print(type((predicted == labels)))
print((predicted == labels).dtype)
print(type((predicted == labels).sum()))
print((predicted == labels).sum())
print((predicted == labels).sum().item())
# <class 'torch.Tensor'>
# torch.bool
# <class 'torch.Tensor'>
# tensor(2)
# 2

Je vois. Comparez chaque élément du tableau et utilisez sum () implémenté dans torch.Tensor pour calculer la valeur totale de true. Ensuite, item () implémenté dans torch.Tensor est utilisé pour faire de la valeur totale une valeur numérique de type int. C'était un peu plus facile à comprendre quand je l'ai vérifié avec numpy.

#Essayez avec numpy
a = np.array([1, 2, 3, 4, 5])
b = np.array([1, 2, 0, 4, 5])
print(type((a == b)))
print((a == b))
print((a == b).sum())
print(type((a == b).sum()))
print((a == b).sum().item())
print(type((a == b).sum().item()))
# <class 'numpy.ndarray'>
# [ True  True False  True  True]
# 4
# <class 'numpy.int64'>
# 4
# <class 'int'>

En regardant Official, vous pouvez utiliser presque la même API que ndarray, donc ** sum () ** et ** item () ** Peut être utilisé. Convaincu.

Voyons le taux de réponse correct pour chaque étiquette

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of airplane : 72 %
Accuracy of automobile : 66 %
Accuracy of  bird : 38 %
Accuracy of   cat : 58 %
Accuracy of  deer : 60 %
Accuracy of   dog : 29 %
Accuracy of  frog : 73 %
Accuracy of horse : 60 %
Accuracy of  ship : 69 %
Accuracy of truck : 73 %

Est-ce que c'est comme ça dans un tutoriel?

Recommended Posts

[PyTorch] Classification des images du CIFAR-10
Classification CIFAR-10 implémentée dans près de 60 lignes dans PyTorch
Classification multi-étiquette d'images multi-classes avec pytorch
[kotlin] Trier les images sur Android (Pytorch Mobile)
Image de fermeture
[Classification des images] Analyse faciale du chien
J'ai essayé la reconnaissance d'image de CIFAR-10 avec Keras-Learning-
J'ai essayé la reconnaissance d'image de CIFAR-10 avec la reconnaissance d'image Keras-
Juge Yosakoi Naruko par classification d'image de Tensorflow.
Résumé super (concis) de la classification des images par ArcFace
Indice de classification typique
Tutoriel [PyTorch] (version japonaise) ④ ~ FORMATION D'UN CLASSIFICATEUR (classification d'images) ~
J'ai essayé la classification d'image d'AutoGluon
Classification d'images avec un réseau de neurones auto-fabriqué par Keras et PyTorch
CNN (1) pour la classification des images (pour les débutants)
Apprendre avec l'enseignant 1 Principes de base de l'apprentissage avec l'enseignant (classification)
Application de la reconnaissance d'image CNN2
Prédiction de la moyenne Nikkei avec Pytorch 2
Classifier les ensembles de données d'image CIFAR-10 à l'aide de divers modèles d'apprentissage en profondeur
Prédiction de la moyenne Nikkei avec Pytorch
Classification en temps réel de plusieurs objets dans les images de la caméra avec apprentissage en profondeur de Raspberry Pi 3 B + et PyTorch
Je voulais contester la classification du CIFAR-10 en utilisant l'entraîneur de Chainer
Capture d'image de Firefox en utilisant Python
Jugement de l'image rétroéclairée avec OpenCV
Résumé de l'implémentation de base par PyTorch
Extraire les points caractéristiques d'une image
[Détails (?)] Introduction au pytorch ~ CNN de CIFAR10 ~
Classification d'images avec un jeu de données d'images de fond d'oeil grand angle
Prédiction de la moyenne Nikkei avec Pytorch ~ Makuma ~
Traduction japonaise appropriée de pytorch tensor_tutorial
Reconnaissance d'image des fruits avec VGG16
"Classer les déchets par image!" Journal de création d'application jour6 ~ Correction de la structure des répertoires ~