I tried building handwritten kana recognition, Part 1/3: Starting with MNIST

Overview

I built a system in which kana is drawn in a GUI and the character is recognized by a model trained in advance with machine learning.

First, I check the feel and accuracy of a CNN on MNIST, then train on actual kana data, and finally connect the model to the GUI.

Next (2/3): https://qiita.com/tfull_tf/items/968bdb8f24f80d57617e
Next (3/3): https://qiita.com/tfull_tf/items/d9fe3ab6c1e47d1b2e1e

The entire code can be found at: https://github.com/tfull/character_recognition

Model building with MNIST

I build my own model, then train and test it on MNIST, the popular handwritten digit dataset, to see how accurate it is.

Since MNIST is 28x28 grayscale data, each image is input with shape (channel, height, width) = (1, 28, 28). The digits are 0 to 9, so there are 10 classes, and the model outputs 10 probabilities.

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(0.3)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(12 * 12 * 32, 128)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.linear2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu3(x)
        x = self.dropout2(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

The input passes through two convolutional layers and a pooling layer, is flattened to one dimension, and then passes through two fully connected layers. The activation function is ReLU, and dropout layers are inserted along the way to prevent overfitting.
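The input size of linear1 (12 * 12 * 32) follows from the layer settings. The sketch below retraces that arithmetic in plain Python using the standard output-size formula for convolution and pooling with no padding; the helper name conv_out is made up for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # MNIST images are 28x28
size = conv_out(size, kernel=3)             # conv1 (3x3 kernel): 28 -> 26
size = conv_out(size, kernel=3)             # conv2 (3x3 kernel): 26 -> 24
size = conv_out(size, kernel=2, stride=2)   # 2x2 max pooling:    24 -> 12

channels = 32                               # output channels of conv2
flattened = size * size * channels          # number of features after flatten
print(size, flattened)                      # prints: 12 4608
```

If the kernel size or the input image size changes, the first argument of linear1 has to change accordingly.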

Data acquisition

import os
import torchvision

# data_directory must be defined beforehand (the directory where datasets are stored)
download_flag = not os.path.exists(data_directory + "/mnist")

mnist_train = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = True,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

mnist_test = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = False,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

The MNIST data is saved locally and reused. With data_directory defined, the data is downloaded only if data_directory + "/mnist" does not exist yet, so the download happens only on the first run.
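The download-only-once check can be tried without touching the network. This is a minimal sketch using a temporary directory; the function name needs_download is hypothetical and only mirrors how download_flag is computed above.

```python
import os
import tempfile

def needs_download(directory):
    """True when the dataset directory does not exist yet."""
    return not os.path.exists(directory)

with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "mnist")
    first = needs_download(target)    # directory is missing -> True
    os.makedirs(target)               # stands in for the first download
    second = needs_download(target)   # directory now exists -> False

print(first, second)                  # prints: True False
```

One caveat of a directory-level check: if the first download is interrupted after the directory has been created, the flag becomes False on the next run even though the files are incomplete.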

Preparation for learning

import torch
import torch.optim as optim

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size = 100, shuffle = True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 1000, shuffle = False)

model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

The DataLoader retrieves the data batch by batch: 100 shuffled images per batch for training, and 1000 images per batch for testing.

Next, set up the model, the loss function, and the optimization algorithm. I use cross-entropy loss and Adam.
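One detail worth knowing about this combination: nn.CrossEntropyLoss expects raw scores (logits) and applies log-softmax internally, while the model above already ends with nn.Softmax, so softmax is effectively applied twice when the loss is computed. The sketch below reproduces the calculation in plain Python to show the effect; the helper functions and the example scores are made up for illustration and are not the library's code.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of scores."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(scores, label):
    """-log(softmax(scores)[label]): the loss for a single sample."""
    return -math.log(softmax(scores)[label])

# Hypothetical raw scores for a confident, correct prediction of class 0.
logits = [12.0] + [0.0] * 9

loss_from_logits = cross_entropy(logits, 0)          # close to 0
loss_from_probs = cross_entropy(softmax(logits), 0)  # softmax applied twice

# Inputs in [0, 1] cannot push the loss below the value for a perfect one-hot vector.
floor = cross_entropy([1.0] + [0.0] * 9, 0)          # log(1 + 9 / e), about 1.4612

print(loss_from_logits < 0.001, round(floor, 4))     # prints: True 1.4612
```

The predicted class is unchanged because softmax preserves the ordering of the scores, but with 10 classes the loss value cannot go below about 1.46. Whether dropping the final softmax changes the accuracy here was not measured in this article.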

Training

n_epoch = 2

model.train()

for i_epoch in range(n_epoch):
    for i_batch, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        print("epoch: {}, train: {}, loss: {}".format(i_epoch + 1, i_batch + 1, loss.item()))

The training steps run in a loop: the image data (inputs) is given to the model, the output (outputs) is compared with the correct labels (labels) to obtain the loss, and backpropagation is performed.

Showing each sample only once did not seem like enough training, so I set the number of epochs (n_epoch) to 2, which means each sample is used n_epoch times. In my experience about 2 to 3 epochs works well here, although the right number probably depends on the amount of data.
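One way to judge whether the number of epochs is sufficient is to compare the average loss of each epoch instead of the per-batch values printed above, which fluctuate from batch to batch. A minimal sketch, assuming (epoch, loss) pairs were collected in a list during training; the helper name mean_loss_per_epoch and the loss values are made up for illustration and are not results from this model.

```python
def mean_loss_per_epoch(records):
    """records: list of (epoch, loss) pairs. Returns {epoch: mean loss}."""
    totals = {}
    counts = {}
    for epoch, loss in records:
        totals[epoch] = totals.get(epoch, 0.0) + loss
        counts[epoch] = counts.get(epoch, 0) + 1
    return {epoch: totals[epoch] / counts[epoch] for epoch in totals}

# Hypothetical values, as if appended with records.append((i_epoch + 1, loss.item())).
records = [(1, 2.0), (1, 1.0), (2, 0.5), (2, 0.25)]

means = mean_loss_per_epoch(records)
print(means)  # prints: {1: 1.5, 2: 0.375}
```

If the average stops decreasing from one epoch to the next, additional epochs are unlikely to help much.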

Evaluation

correct_count = 0
record_count = 0

model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, prediction = torch.max(outputs, 1)
        judge = prediction == labels
        correct_count += int(judge.sum())
        record_count += len(judge)

print("Accuracy: {:.2f}%".format(correct_count / record_count * 100))

The image data (inputs) is given to the model, and the class with the highest of the 10 output probabilities becomes the prediction (prediction). Comparing the prediction with the correct labels (labels) yields True / False for each sample, and the accuracy is the number of True values (correct_count) divided by the total number of samples (record_count).
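The same accuracy calculation can be traced by hand on a tiny example. This is a plain-Python sketch with made-up probabilities for 3 classes and 4 samples; the argmax helper plays the role of torch.max(outputs, 1) in the code above.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical model outputs: one row of class probabilities per sample.
outputs = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.5, 0.3],
]
labels = [1, 0, 1, 1]

predictions = [argmax(row) for row in outputs]        # [1, 0, 2, 1]
judge = [p == t for p, t in zip(predictions, labels)]  # [True, True, False, True]

correct_count = sum(judge)    # True counts as 1
record_count = len(judge)
accuracy = correct_count / record_count * 100
print("Accuracy: {:.2f}%".format(accuracy))            # prints: Accuracy: 75.00%
```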

Results and discussion

Averaged over multiple runs, the accuracy was about 97%.

That is a high accuracy, but it still means the model is wrong about 3 times out of 100. Whether people can tolerate that is another matter. However, MNIST contains some messy characters that are hard even for humans to tell apart, so in that sense a 3% error rate may be unavoidable.

MNIST has only 10 classes, 0 to 9, but kana has more than 100 characters across hiragana and katakana, so classification will be harder and a lower accuracy should be expected.
