[Kotlin] Image classification on Android (PyTorch Mobile)

PyTorch Mobile came out around October of last year (2019) with PyTorch 1.3. On-device machine learning on Android and iOS was already possible with frameworks such as TensorFlow Lite, but PyTorch finally has its own mobile runtime. This is great news for those of us who use PyTorch rather than TensorFlow! Like TensorFlow Lite, it runs on both Android and iOS.

For details, see the PyTorch Mobile official website: https://pytorch.org/mobile/home/

[Image: from the official PyTorch Mobile website]

What we'll do

We follow the tutorial on the official website, but write it in Kotlin! We classify images using a pretrained ResNet model (inference only).

The code is on GitHub: https://github.com/SY-BETA/PyTorchMobile

The finished app is a simple one that only displays the image to be classified, the top two classification results, and their scores. (What is Canis lupus?)

Requirements

- A Python execution environment (I used Jupyter Notebook)
- PyTorch and torchvision (latest version recommended)

That's all.

Download the ResNet model

First, create a new project in Android Studio and add an assets folder to it (right-click app in the Project pane on the left, then New -> Folder -> Assets Folder). Then run the following Python code from the project root, that is, the directory that contains the app folder, so that the relative save path below resolves correctly.

createModel.py


import torch
import torchvision

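# Load ResNet-18 with ImageNet-pretrained weights
# (newer torchvision versions prefer weights=torchvision.models.ResNet18_Weights.DEFAULT)
model = torchvision.models.resnet18(pretrained=True)
# Switch to inference mode (batch norm uses its running statistics)
model.eval()
# Dummy input with the shape the model expects: (batch, channels, height, width)
example = torch.rand(1, 3, 224, 224)
# Trace the model into TorchScript and save it into the Android assets folder
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/resnet.pt")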

If it runs successfully, a file called resnet.pt is added to the assets folder you created earlier.

Save the following sample image in both the assets folder and the drawable folder under the name `image.jpg`.

Implementation

Dependencies

Add the following to the app-level build.gradle (as of January 4, 2020):

dependencies {
    implementation 'org.pytorch:pytorch_android:1.3.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}

Create the layout in Android Studio

Create a simple layout with one ImageView and six TextViews stacked vertically.

activity_main.xml


<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <TextView
        android:id="@+id/textView"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Input"
        android:textSize="30sp"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="wrap_content"
        android:layout_height="230dp"
        android:scaleType="fitCenter"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/textView"
        app:srcCompat="@drawable/image" />

    <TextView
        android:id="@+id/textView2"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Result"
        android:textSize="30sp"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView" />

    <TextView
        android:id="@+id/result1Score"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="32dp"
        android:text="TextView"
        android:textSize="18sp"
        app:layout_constraintBottom_toTopOf="@+id/result1Class"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/textView2" />

    <TextView
        android:id="@+id/result1Class"
        android:layout_width="250dp"
        android:layout_height="wrap_content"
        android:layout_marginStart="40dp"
        android:layout_marginTop="8dp"
        android:layout_marginEnd="40dp"
        android:gravity="center"
        android:text="TextView"
        android:textSize="18sp"
        app:layout_constraintBottom_toTopOf="@+id/result2Score"
        app:layout_constraintEnd_toEndOf="@+id/result1Score"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="@+id/result1Score"
        app:layout_constraintTop_toBottomOf="@+id/result1Score" />

    <TextView
        android:id="@+id/result2Score"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="24dp"
        android:text="TextView"
        android:textSize="18sp"
        app:layout_constraintBottom_toTopOf="@+id/result2Class"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/result1Class"
        app:layout_constraintVertical_bias="0.94" />

    <TextView
        android:id="@+id/result2Class"
        android:layout_width="250dp"
        android:layout_height="wrap_content"
        android:layout_marginStart="40dp"
        android:layout_marginTop="8dp"
        android:layout_marginEnd="40dp"
        android:layout_marginBottom="32dp"
        android:gravity="center"
        android:text="TextView"
        android:textSize="18sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="@+id/result2Score"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="@+id/result2Score"
        app:layout_constraintTop_toBottomOf="@+id/result2Score" />
</androidx.constraintlayout.widget.ConstraintLayout>

Model loading

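Load the resnet.pt file created earlier. Module.load() needs a path on the file system, so the model is first copied from the assets folder into the app's internal storage.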

MainActivity.kt


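import android.content.Context
import android.graphics.BitmapFactory
import android.os.Bundle
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils
import java.io.File
import java.io.FileOutputStream

class MainActivity : AppCompatActivity() {

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        // Copy an asset into internal storage (only if not copied yet) and return its absolute path,
        // because Module.load() needs a real file path rather than an asset stream
        fun assetFilePath(context: Context, assetName: String): String {
            val file = File(context.filesDir, assetName)
            if (file.exists() && file.length() > 0) {
                return file.absolutePath
            }
            context.assets.open(assetName).use { inputStream ->
                FileOutputStream(file).use { outputStream ->
                    val buffer = ByteArray(4 * 1024)
                    var read: Int
                    while (inputStream.read(buffer).also { read = it } != -1) {
                        outputStream.write(buffer, 0, read)
                    }
                    outputStream.flush()
                }
            }
            return file.absolutePath
        }

        // Load the sample image from the assets folder
        val bitmap = BitmapFactory.decodeStream(assets.open("image.jpg"))
        // Load the serialized TorchScript model
        val module = Module.load(assetFilePath(this, "resnet.pt"))

        // The inference code in the following steps continues here, inside onCreate()
    }
}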

Note that loading images and models from the assets folder is somewhat cumbersome: the model has to be copied to a real file before Module.load() can read it.
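The manual buffer loop in assetFilePath can also be written with Kotlin's standard InputStream.copyTo. Below is a minimal plain-JVM sketch of the same copy-once idea, so it can be tried outside Android. The name copyOnce and its stream-opening lambda are my own hypothetical choices; on Android you would pass File(context.filesDir, name) as the target and { context.assets.open(name) } as the opener.

```kotlin
import java.io.File
import java.io.InputStream

// Hypothetical helper: copies the stream returned by openInput into target,
// but only if target does not already hold a non-empty copy.
// Returns the absolute path, which is what Module.load() expects.
fun copyOnce(target: File, openInput: () -> InputStream): String {
    // Reuse the existing copy instead of copying again on every launch
    if (target.exists() && target.length() > 0) return target.absolutePath
    openInput().use { source ->
        target.outputStream().use { sink ->
            // copyTo (Kotlin stdlib) runs the read/write buffer loop for us
            source.copyTo(sink)
        }
    }
    return target.absolutePath
}
```

Passing a lambda rather than an already-open stream means the asset is only opened when a copy is actually needed, so nothing is left unclosed on the fast path.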

Inference

Convert the sample image into a tensor with the torchvision helper (pytorch_android_torchvision) added to the dependencies, feed it to ResNet, and get the output.

MainActivity.kt


        // Convert the bitmap to a normalized float tensor of shape [1, 3, height, width]
        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            bitmap,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB
        )

        // Run inference: forward pass through the model
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
        // One raw score per ImageNet class (1000 in total)
        val scores = outputTensor.dataAsFloatArray
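bitmapToFloat32Tensor hides the preprocessing step. As far as I understand it, it scales each RGB channel to 0..1, normalizes it with the ImageNet mean and standard deviation, and writes the channels one after another (CHW order) into the tensor. The constants below are what I assume TORCHVISION_NORM_MEAN_RGB and TORCHVISION_NORM_STD_RGB contain, and argbToNormalizedChw is a hypothetical function that reproduces the arithmetic on a plain IntArray of ARGB pixels (for example one filled by Bitmap.getPixels), so it can be checked without a device.

```kotlin
// Assumed ImageNet normalization constants (the values torchvision uses)
val IMAGENET_MEAN_RGB = floatArrayOf(0.485f, 0.456f, 0.406f)
val IMAGENET_STD_RGB = floatArrayOf(0.229f, 0.224f, 0.225f)

// Hypothetical re-implementation: ARGB pixels in row-major order (width * height)
// become a channel-planar float array: all R values, then all G, then all B.
fun argbToNormalizedChw(pixels: IntArray, width: Int, height: Int): FloatArray {
    require(pixels.size == width * height) { "pixel count must equal width * height" }
    val planeSize = width * height
    val out = FloatArray(3 * planeSize)
    for (i in 0 until planeSize) {
        val p = pixels[i]
        // Extract each 8-bit channel and scale it to 0..1
        val r = ((p shr 16) and 0xFF) / 255f
        val g = ((p shr 8) and 0xFF) / 255f
        val b = (p and 0xFF) / 255f
        // (value - mean) / std, written into the plane for that channel
        out[i] = (r - IMAGENET_MEAN_RGB[0]) / IMAGENET_STD_RGB[0]
        out[planeSize + i] = (g - IMAGENET_MEAN_RGB[1]) / IMAGENET_STD_RGB[1]
        out[2 * planeSize + i] = (b - IMAGENET_MEAN_RGB[2]) / IMAGENET_STD_RGB[2]
    }
    return out
}
```

Seeing the layout spelled out helps when a model misbehaves: feeding HWC data, or skipping normalization, still runs but produces meaningless scores.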

Inference result

Pick the two classes with the highest scores.

MainActivity.kt


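        // Best and second-best scores and their class indices.
        // Start from the lowest float because raw scores can be negative.
        var maxScore = -Float.MAX_VALUE
        var maxScoreIdx = -1
        var maxSecondScore = -Float.MAX_VALUE
        var maxSecondScoreIdx = -1

        // Single pass that keeps track of the top two scores
        for (i in scores.indices) {
            if (scores[i] > maxScore) {
                // New best: the previous best becomes the second best
                maxSecondScore = maxScore
                maxSecondScoreIdx = maxScoreIdx
                maxScore = scores[i]
                maxScoreIdx = i
            } else if (scores[i] > maxSecondScore) {
                // Not the best, but better than the current second best
                maxSecondScore = scores[i]
                maxSecondScoreIdx = i
            }
        }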
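If you want more than two results, sorting the class indices by score is a simple alternative to tracking each rank by hand. A minimal sketch; topK is a hypothetical helper name, not part of PyTorch Android:

```kotlin
// Hypothetical helper: indices of the k largest scores, best first.
// sortedByDescending is stable, so ties keep their original index order.
fun topK(scores: FloatArray, k: Int): List<Int> =
    scores.indices
        .sortedByDescending { scores[it] }
        .take(k)
```

For example, topK(scores, 5) gives the five best class indices, which can then be looked up in IMAGENET_CLASSES. Sorting all 1000 indices is O(n log n), which is negligible at this size.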

Class names

The list of class names is omitted here because it is very long (ImageNet has 1000 classes). It is on GitHub, so please copy the contents of `ImageNetClasses.kt` from there.

Class name list on GitHub (ImageNetClasses.kt)

ImageNetClasses.kt


class ImageNetClasses {
    // Index i holds the name of ImageNet class i
    val IMAGENET_CLASSES = arrayOf(
        "tench, Tinca tinca",
        "goldfish, Carassius auratus",
        // ... omitted: copy the full list of 1000 classes from GitHub ...
        "toilet tissue, toilet paper, bathroom tissue"
    )
}
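If you would rather not keep 1000 strings in Kotlin source, one option (my own assumption, not part of the tutorial) is to store the names in a text file with one name per line and read them at startup. A minimal sketch; readClassNames and the file name used below are hypothetical:

```kotlin
import java.io.InputStream

// Hypothetical alternative to the hard-coded array: read class names from a text
// stream with one name per line. Line i must hold the name of class index i,
// so a blank line in the middle would shift every later index.
fun readClassNames(input: InputStream): List<String> =
    input.bufferedReader().useLines { lines -> lines.toList() }
```

On Android this could be called as readClassNames(assets.open("imagenet_classes.txt")), where imagenet_classes.txt is a file you would create yourself from the list on GitHub.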

Display the results

Look up the class names from the indices, then display the inference results in the layout.

MainActivity.kt


        // Look up the class names from the predicted indices
        val classes = ImageNetClasses().IMAGENET_CLASSES
        val className = classes[maxScoreIdx]
        val className2 = classes[maxSecondScoreIdx]

        // Show the scores and class names in the TextViews
        findViewById<TextView>(R.id.result1Score).text = "score: $maxScore"
        findViewById<TextView>(R.id.result1Class).text = "Classification result: $className"
        findViewById<TextView>(R.id.result2Score).text = "score: $maxSecondScore"
        findViewById<TextView>(R.id.result2Class).text = "Classification result: $className2"
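The displayed score is the raw output of ResNet's final fully connected layer (a logit); the traced model has no softmax at the end, so the value is not a probability and is not bounded to 0..1. If you would rather show a probability, apply a softmax over all 1000 scores first. A minimal sketch, with softmax as a hypothetical helper name:

```kotlin
import kotlin.math.exp

// Converts raw scores (logits) into probabilities that sum to 1.
// Subtracting the maximum first keeps exp() from overflowing.
fun softmax(logits: FloatArray): FloatArray {
    val max = logits.maxOrNull() ?: return FloatArray(0)
    val exps = FloatArray(logits.size) { exp(logits[it] - max) }
    val sum = exps.sum()
    return FloatArray(exps.size) { exps[it] / sum }
}
```

For example, softmax(scores)[maxScoreIdx] gives the probability of the top class. Softmax preserves the ordering of the scores, so the top-two selection does not change.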

Done! If you build and run it, you should get a screen like the one at the beginning. Try it with various pictures.

Conclusion

The library is convenient: image classification takes only this much code. Converting the image to a tensor tripped me up a little, but now I can use PyTorch on Android too. As an aside, at first my PyTorch version was not the latest, so loading the model failed with an error and nothing worked at all. I also got stuck for quite a while on getting the path of a file in the assets folder.
