PyTorch Mobile came out around October of last year (2019), as part of PyTorch 1.3. Machine learning on Android and iOS was already possible with tools such as TensorFlow Lite, but now PyTorch finally has official mobile support too. Great news for those of us who use PyTorch rather than TensorFlow! Like TensorFlow Lite, it runs on both Android and iOS.
For details, see the official PyTorch Mobile website: https://pytorch.org/mobile/home/
In this post I follow the tutorial from the official website, but write it in Kotlin! The app classifies an image using a pretrained ResNet model (inference only).
The code is on GitHub: https://github.com/SY-BETA/PyTorchMobile
The finished app is a simple one that only displays the image to be classified, the top two classification results, and their scores. (What is Canis lupus, anyway?)
What you need:
- A Python execution environment (I used Jupyter Notebook)
- PyTorch and torchvision (latest versions recommended)

That's all.
First, create a new project in Android Studio and add an assets folder to it (right-click app in the project pane on the left, then New -> Folder -> Assets Folder). Then run the following Python code from the project's root directory, i.e. the directory that contains the app folder, because the save path in the script is relative to it.
createModel.py
import torch
import torchvision
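# Use ResNet-18 pretrained on ImageNet
# (newer torchvision versions prefer weights=torchvision.models.ResNet18_Weights.DEFAULT)
model = torchvision.models.resnet18(pretrained=True)
# Switch to inference mode
model.eval()
# Dummy input with the shape the model expects: (batch, channels, height, width)
example = torch.rand(1, 3, 224, 224)
# Trace the model into TorchScript and save it to the Android assets folder
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/resnet.pt")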
If it runs successfully, a file called resnet.pt will be added to the assets folder you created earlier.
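Save a sample image named `image.jpg` in both the assets folder and the drawable folder (the layout displays the drawable copy, and the code classifies the assets copy).
Add the following to the app module's build.gradle (versions as of January 4, 2020):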
dependencies {
implementation 'org.pytorch:pytorch_android:1.3.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}
Create the layout. Mine is just one ImageView and six TextViews stacked vertically.
activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/textView"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Input"
android:textSize="30sp"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
<ImageView
android:id="@+id/imageView"
android:layout_width="wrap_content"
android:layout_height="230dp"
android:scaleType="fitCenter"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/textView"
app:srcCompat="@drawable/image" />
<TextView
android:id="@+id/textView2"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Result"
android:textSize="30sp"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/imageView" />
<TextView
android:id="@+id/result1Score"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="32dp"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toTopOf="@+id/result1Class"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/textView2" />
<TextView
android:id="@+id/result1Class"
android:layout_width="250dp"
android:layout_height="wrap_content"
android:layout_marginStart="40dp"
android:layout_marginTop="8dp"
android:layout_marginEnd="40dp"
android:gravity="center"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toTopOf="@+id/result2Score"
app:layout_constraintEnd_toEndOf="@+id/result1Score"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="@+id/result1Score"
app:layout_constraintTop_toBottomOf="@+id/result1Score" />
<TextView
android:id="@+id/result2Score"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="24dp"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toTopOf="@+id/result2Class"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/result1Class"
app:layout_constraintVertical_bias="0.94" />
<TextView
android:id="@+id/result2Class"
android:layout_width="250dp"
android:layout_height="wrap_content"
android:layout_marginStart="40dp"
android:layout_marginTop="8dp"
android:layout_marginEnd="40dp"
android:layout_marginBottom="32dp"
android:gravity="center"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="@+id/result2Score"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="@+id/result2Score"
app:layout_constraintTop_toBottomOf="@+id/result2Score" />
</androidx.constraintlayout.widget.ConstraintLayout>
Load the resnet.pt model created earlier.
MainActivity.kt
// Add these imports at the top of MainActivity.kt (in addition to the default ones)
import android.content.Context
import android.graphics.BitmapFactory
import org.pytorch.Module
import java.io.File
import java.io.FileOutputStream

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
///Copy an asset into internal storage (only once) and return its file path,
///because Module.load() needs a real file path, not an asset stream
fun assetFilePath(context: Context, assetName: String): String {
val file = File(context.filesDir, assetName)
if (file.exists() && file.length() > 0) {
return file.absolutePath
}
context.assets.open(assetName).use { inputStream ->
FileOutputStream(file).use { outputStream ->
val buffer = ByteArray(4 * 1024)
var read: Int
while (inputStream.read(buffer).also { read = it } != -1) {
outputStream.write(buffer, 0, read)
}
outputStream.flush()
}
return file.absolutePath
}
}
///Load the image and the serialized model
val bitmap = BitmapFactory.decodeStream(assets.open("image.jpg"))
val module = Module.load(assetFilePath(this, "resnet.pt"))
///The inference code in the following steps also goes here, inside onCreate
}
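Note that loading the model from the assets folder is a bit cumbersome: Module.load() takes a file path, but assets are packed inside the APK rather than stored as regular files, so assetFilePath first copies resnet.pt into the app's internal files directory.
Next, convert the sample image into a tensor with the torchvision helper library added to the dependencies, feed it to ResNet, and get the output.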
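For reference, here is a minimal sketch of the same copy-once idea using Kotlin's standard `InputStream.copyTo` instead of a manual buffer loop. The name `cachedCopy` is hypothetical. It takes a lambda that opens the source stream, so that, like the helper above, nothing is opened when a cached copy already exists. It only uses `java.io` and `kotlin.io`, so it also runs on a plain JVM.

```kotlin
import java.io.File
import java.io.InputStream

// Hypothetical helper: copy the stream produced by openSource into targetFile
// the first time, then reuse the existing non-empty copy on later calls.
fun cachedCopy(targetFile: File, openSource: () -> InputStream): String {
    // Reuse an earlier copy without opening the source at all
    if (targetFile.exists() && targetFile.length() > 0) {
        return targetFile.absolutePath
    }
    openSource().use { source ->
        targetFile.outputStream().use { sink ->
            // copyTo runs the read/write buffer loop for us
            source.copyTo(sink)
        }
    }
    return targetFile.absolutePath
}
```

On Android it could be called as `Module.load(cachedCopy(File(filesDir, "resnet.pt")) { assets.open("resnet.pt") })`.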
MainActivity.kt
///Additional imports: org.pytorch.IValue, org.pytorch.torchvision.TensorImageUtils
///Convert the bitmap to a normalized float tensor of shape [1, 3, H, W]
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB
)
///Inference: forward propagation
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
///One raw score per ImageNet class (1000 values)
val scores = outputTensor.dataAsFloatArray
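What `bitmapToFloat32Tensor` does with the mean/std arguments is, conceptually, per-channel normalization: each 0-255 channel value x becomes (x / 255 - mean) / std, and the values are laid out channel by channel (CHW), matching the `(1, 3, 224, 224)` layout used when tracing. The sketch below reproduces that arithmetic on a plain `IntArray` of ARGB pixels (the format `Bitmap.getPixels` fills). It is an illustration, not the library's actual source, and `pixelsToNormalizedChw` is a hypothetical name. The mean/std values are the standard ImageNet statistics held by `TORCHVISION_NORM_MEAN_RGB` and `TORCHVISION_NORM_STD_RGB`.

```kotlin
// Standard ImageNet per-channel statistics (R, G, B)
val NORM_MEAN_RGB = floatArrayOf(0.485f, 0.456f, 0.406f)
val NORM_STD_RGB = floatArrayOf(0.229f, 0.224f, 0.225f)

// Hypothetical helper: ARGB pixels -> normalized floats in CHW order
// (all R values first, then all G values, then all B values).
fun pixelsToNormalizedChw(pixels: IntArray, mean: FloatArray, std: FloatArray): FloatArray {
    val planeSize = pixels.size
    val out = FloatArray(3 * planeSize)
    for ((i, argb) in pixels.withIndex()) {
        // Extract each 8-bit channel and scale it to [0, 1]
        val r = ((argb shr 16) and 0xFF) / 255f
        val g = ((argb shr 8) and 0xFF) / 255f
        val b = (argb and 0xFF) / 255f
        // Subtract the channel mean and divide by the channel std
        out[i] = (r - mean[0]) / std[0]
        out[planeSize + i] = (g - mean[1]) / std[1]
        out[2 * planeSize + i] = (b - mean[2]) / std[2]
    }
    return out
}
```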
Extract the two highest scores.
MainActivity.kt
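///Variables to store the top two scores and their indices
///(start below any possible value, since raw scores can be negative)
var maxScore = -Float.MAX_VALUE
var maxScoreIdx = -1
var maxSecondScore = -Float.MAX_VALUE
var maxSecondScoreIdx = -1
///Take the top two with the highest score
for (i in scores.indices) {
if (scores[i] > maxScore) {
///New best: the previous best becomes the runner-up
maxSecondScore = maxScore
maxSecondScoreIdx = maxScoreIdx
maxScore = scores[i]
maxScoreIdx = i
} else if (scores[i] > maxSecondScore) {
///Not the best, but better than the current runner-up
maxSecondScore = scores[i]
maxSecondScoreIdx = i
}
}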
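Tracking the top two by hand means updating both slots correctly on every iteration. If you want the top k for any k, a sketch like the following (the helper name `topKIndices` is hypothetical) sorts the class indices by score instead; for 1000 classes the sorting cost is negligible.

```kotlin
// Hypothetical helper: indices of the k largest scores, best first.
fun topKIndices(scores: FloatArray, k: Int): List<Int> =
    scores.indices
        .sortedByDescending { scores[it] } // order class indices by their score
        .take(k)                           // keep only the best k
```

For example, `val (first, second) = topKIndices(scores, 2)` gives the indices of the top two classes.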
The class names are omitted here because the list is very long (this is ImageNet 1000-class classification). Please copy the contents of `ImageNetClasses.kt` from the class name list on GitHub.
ImageNetClasses.kt
class ImageNetClasses {
val IMAGENET_CLASSES = arrayOf(
"tench, Tinca tinca",
"goldfish, Carassius auratus",
//~~~~~~~~~~~~~~Omitted (please copy the full list from GitHub)~~~~~~~~~~~~~~~~//
"toilet tissue, toilet paper, bathroom tissue"
)
}
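Instead of hard-coding 1000 strings in a Kotlin class, the labels could also live in a plain text file in the assets folder, one class name per line in index order. The file name `imagenet_classes.txt` and the helper `loadLabels` are hypothetical; this is just one possible alternative.

```kotlin
import java.io.Reader

// Hypothetical helper: read one class name per line, in index order.
// Blank lines are skipped, so the file must not use empty lines as placeholders.
fun loadLabels(reader: Reader): List<String> =
    reader.useLines { lines ->
        lines.map { it.trim() }
            .filter { it.isNotEmpty() }
            .toList()
    }
```

On Android this could be used as `val labels = loadLabels(assets.open("imagenet_classes.txt").bufferedReader())`, followed by `labels[maxScoreIdx]`.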
Look up the class names from the indices and, finally, display the inference results in the layout.
MainActivity.kt
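///Get the class names from the indices
val classes = ImageNetClasses().IMAGENET_CLASSES
val className = classes[maxScoreIdx]
val className2 = classes[maxSecondScoreIdx]
///Assumes result1Score etc. are synthetic view properties from kotlin-android-extensions
///(import kotlinx.android.synthetic.main.activity_main.*);
///without that plugin, use findViewById<TextView>(R.id.result1Score) and so on
result1Score.text = "score: $maxScore"
result1Class.text = "Classification result: $className"
result2Score.text = "score: $maxSecondScore"
result2Class.text = "Classification result: $className2"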
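The displayed score is the raw output of ResNet's last fully connected layer (a logit), not a probability, because the traced model has no softmax at the end. If you would rather show a probability, a numerically stable softmax like this sketch converts the 1000 scores into values that sum to 1. `softmax` here is a hypothetical helper, not part of the PyTorch Android API.

```kotlin
import kotlin.math.exp

// Hypothetical helper: convert raw logits into probabilities that sum to 1.
// Subtracting the maximum first keeps exp() from overflowing.
fun softmax(logits: FloatArray): FloatArray {
    val max = logits.maxOrNull() ?: return FloatArray(0)
    val exps = FloatArray(logits.size) { exp(logits[it] - max) }
    val sum = exps.sum()
    return FloatArray(exps.size) { exps[it] / sum }
}
```

For example, `softmax(scores)[maxScoreIdx]` is the probability the model assigns to the top class.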
Done!! Build and run it, and you should see a screen like the one at the beginning. Try it with various pictures and have fun.
The library is really convenient; it's amazing that image classification takes only this much code. Converting the image to a tensor tripped me up a little, but now I can use PyTorch on Android too. As an aside, at first my PyTorch version wasn't the latest, so I got an error when loading the model and couldn't get it working at all, and getting the path to a file in the assets folder also took me quite a while.