This article shows how to create a classification model from your own image dataset and run it in real time with the camera of an iOS or Android device.
Environment:

- Google Colaboratory (runtime: GPU)
- TensorFlow 1.15
- Google Chrome
In this article, we use retrain.py to create an image classification model[^1].

[^1]: retrain.py is being replaced by make_image_classifier. With make_image_classifier, training and conversion to tflite can be done in one go, and it seems you do not need to change 224 to 299 in the Swift code.
Since retrain.py is all you need to create a model, you can set up the environment with a single `curl` command. There is no need to `git clone` the repository.

To download only retrain.py:

```
curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py
```
Image data can be collected fairly easily with scraping and image-collection tools. I used google-images-download to collect my images[^2].

[^2]: As of March 7, 2020, google-images-download does not work in some environments, apparently because Google changed its search algorithm.

Since [many articles](https://www.google.com/search?sxsrf=ALeKk02U-SqEjAhMNjmpl4-sUbwkSaevTQ:1583514716818&q=google_images_download&spell=1&sa=X&ved=2ahUKEwig7MSBrIboAhW) already explain how to use google-images-download, I will skip that here.

To create a model with retrain.py, arrange the directories as follows:
```
retrain.py
dataset
|-- label_A
|   └─ aaa.jpg
|   └─ bbb.png
|   └─ ccc.jpg
|   ⋮
|-- label_B
|   └─ ddd.png
|   └─ eee.jpg
|   └─ fff.png
⋮   ⋮
```
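Before training, it can help to check that every label folder actually contains images. The helpers below are a minimal, hypothetical sketch: `count_images`, `small_labels` and `label_name` are my own names and are not part of retrain.py. The accepted extensions and the 20-image threshold are assumptions modeled on what retrain.py appears to do, so check them against your copy.

As far as I can tell, retrain.py also derives each label by lowercasing the folder name and replacing runs of non-alphanumeric characters with a space. `label_name` imitates that, so `label_A` would show up as `label a` in output_labels.txt.

```python
import os
import re

# Assumed extension list and warning threshold; check your retrain.py version.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}
MIN_IMAGES_PER_LABEL = 20


def count_images(dataset_dir):
    """Return {folder_name: image_count} for each label folder in dataset_dir."""
    counts = {}
    for entry in sorted(os.listdir(dataset_dir)):
        label_dir = os.path.join(dataset_dir, entry)
        if not os.path.isdir(label_dir):
            continue  # files directly under dataset/ are not labels
        counts[entry] = sum(
            1 for name in os.listdir(label_dir)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
    return counts


def small_labels(counts, minimum=MIN_IMAGES_PER_LABEL):
    """Return the label folders that have fewer than `minimum` images."""
    return [label for label, n in counts.items() if n < minimum]


def label_name(folder_name):
    """Imitate how retrain.py seems to turn a folder name into a label."""
    return re.sub(r'[^a-z0-9]+', ' ', folder_name.lower())
```

For example, `small_labels(count_images('dataset'))` lists the folders that are worth topping up before running retrain.py.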
## 1.2 Modeling
Once retrain.py and the image data are ready, train the model.
retrain.py needs to know where the dataset is, so pass its directory with `--image_dir`.
```
python retrain.py --image_dir dataset
```
You can also pass other arguments, for example to set where the model is written and how many training steps to run[^3].

[^3]: Passing `--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/1` produces a MobileNet-based model. MobileNet is a relatively lightweight model designed for running machine learning on mobile devices.
If you run it without specifying an output location, **output_graph.pb** and **output_labels.txt** are written to **/tmp**.
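If you want to keep the model somewhere other than /tmp, you can build the flags in Python, for example from a Colaboratory cell. This is a sketch only. `build_retrain_command` is a hypothetical helper, and it assumes the retrain.py flags `--output_graph`, `--output_labels`, `--how_many_training_steps` and `--tfhub_module`, so confirm them with `python retrain.py --help`.

Switching to the MobileNet module from footnote 3 changes the model's input size to 224×224. In that case the 299 values used later in this article would change accordingly.

```python
def build_retrain_command(image_dir, output_dir='/tmp',
                          how_many_training_steps=None, tfhub_module=None):
    """Build an argv list for retrain.py (hypothetical helper; flag names assumed)."""
    cmd = [
        'python', 'retrain.py',
        '--image_dir', image_dir,
        # Where the frozen graph and the label file are written.
        '--output_graph', f'{output_dir}/output_graph.pb',
        '--output_labels', f'{output_dir}/output_labels.txt',
    ]
    if how_many_training_steps is not None:
        cmd += ['--how_many_training_steps', str(how_many_training_steps)]
    if tfhub_module is not None:
        # e.g. the MobileNet feature-vector URL from footnote 3
        cmd += ['--tfhub_module', tfhub_module]
    return cmd
```

Passing the result to `subprocess.run(..., check=True)` would then train the model and write both files to the chosen directory.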
## 1.3 (Bonus) Run inference with the model
You can check the model's predictions with [label_image.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/label_image.py).
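label_image.py takes the frozen graph and the label file through flags such as `--graph`, `--labels`, `--image`, `--input_layer` and `--output_layer`. For the retrain.py graph, the layer names would be `Placeholder` and `final_result` (see section 2.3). Its last step ranks the output probabilities and prints the best labels.

The sketch below reproduces only that post-processing in plain Python. `load_labels` and `top_k` are hypothetical helper names, and the probabilities in the usage example are made up.

```python
def load_labels(path):
    """Read output_labels.txt: one label per line, in the model's output order."""
    with open(path) as f:
        return [line.rstrip() for line in f]


def top_k(probabilities, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(range(len(probabilities)),
                    key=lambda i: probabilities[i], reverse=True)
    return [(labels[i], probabilities[i]) for i in ranked[:k]]
```

For example, `top_k([0.1, 0.7, 0.2], ['label a', 'label b', 'label c'], k=2)` returns `[('label b', 0.7), ('label c', 0.2)]`.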
# 2. Convert the created model to tflite format
Convert the generated **output_graph.pb** to the tflite (TensorFlow Lite) format.
## 2.1 Conversion for iOS
The iOS sample app uses a quantized model, so specify `QUANTIZED_UINT8` for `--inference_type` and `--inference_input_type`.
```
tflite_convert \
--graph_def_file=/tmp/output_graph.pb \
--output_file=./quant_graph.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--input_shape=1,299,299,3 \
--input_array=Placeholder \
--output_array=final_result \
--input_data_type=FLOAT \
--default_ranges_min=0 \
--default_ranges_max=6 \
--inference_type=QUANTIZED_UINT8 \
--inference_input_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_dev_values=128
```
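The `--mean_values` and `--std_dev_values` flags describe how a uint8 input maps back to a real value: `real_value = (quantized_value - mean_value) / std_dev_value`. With 128/128, pixel values 0 to 255 become roughly [-1, 1).

`--default_ranges_min` and `--default_ranges_max` supply a fallback range for tensors that carry no min/max information ("dummy quantization"). This makes a float-trained graph like this one convertible, but it may cost accuracy.

The helpers below are a hypothetical sketch that only does the arithmetic. They do not call the converter.

```python
def dequantize_input(q, mean=128.0, std_dev=128.0):
    """Real value the model sees for a uint8 input q: (q - mean) / std_dev."""
    return (q - mean) / std_dev


def uint8_step(range_min=0.0, range_max=6.0, levels=255):
    """Width of one uint8 step when [range_min, range_max] is split into 255 steps."""
    return (range_max - range_min) / levels
```

For example, `dequantize_input(255)` is 0.9921875 and `uint8_step()` is about 0.0235. Tensors quantized with the default range would be limited to 0 to 6 and rounded to steps of that size.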
## 2.2 Conversion for Android
On my Android device, the GPU did not support the quantized model, so I use a float model instead. Specify `FLOAT` for `--inference_type` and `--inference_input_type`, and change `--output_file` to `float_graph.tflite`.
```
tflite_convert \
--graph_def_file=/tmp/output_graph.pb \
--output_file=./float_graph.tflite \
⋮
--inference_type=FLOAT \
--inference_input_type=FLOAT \
⋮
```
## 2.3 Problems when converting to tflite
- The command for converting to tflite differs between TensorFlow versions.
- A model created with TensorFlow 1.x could not be converted with the 2.x converter.
- I didn't know what to specify for `--input_array` and `--output_array`.
To find out what to specify for `--input_array` and `--output_array`, I wrote the following script.
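```python
import tensorflow as tf

# Load the frozen graph written by retrain.py.
# tf.compat.v1.GraphDef works on TF 1.15 and 2.x (tf.GraphDef only on 1.x).
graph_def = tf.compat.v1.GraphDef()
with open('/tmp/output_graph.pb', 'rb') as m_file:
    graph_def.ParseFromString(m_file.read())

# Save every node name to a text file ('w' so reruns don't append duplicates).
node_names = [node.name for node in graph_def.node]
with open('somefile.txt', 'w') as the_file:
    for name in node_names:
        the_file.write(name + '\n')

# In this graph, the first node is the input and the last node is the output.
print("Output name =")
print(node_names[-1])
print("Input name =")
print(node_names[0])
```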
The execution result looks like this.
```
Output name =
final_result
Input name =
Placeholder
```
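The script above assumes that the input is the first node and the output is the last, which holds for this graph. A slightly more general heuristic, sketched below, treats every `Placeholder` node as an input and every node that no other node consumes as an output.

`find_inputs_and_outputs` is a hypothetical helper. It works on plain `(name, op, inputs)` tuples so that it can be tried without TensorFlow. With a loaded GraphDef, you would build the tuples as `[(n.name, n.op, list(n.input)) for n in graph_def.node]`. On a real graph it may also report a few dangling nodes, so treat its output as candidates.

```python
def find_inputs_and_outputs(nodes):
    """nodes: iterable of (name, op, inputs) tuples describing a graph.

    Inputs are Placeholder nodes; outputs are nodes that nothing consumes.
    """
    nodes = list(nodes)
    consumed = set()
    for _, _, inputs in nodes:
        for ref in inputs:
            # "^name" is a control dependency; "name:1" selects an output index.
            consumed.add(ref.lstrip('^').split(':')[0])
    inputs = [name for name, op, _ in nodes if op == 'Placeholder']
    outputs = [name for name, _, _ in nodes if name not in consumed]
    return inputs, outputs
```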
# 3. Try it on mobile
Use the sample code from [tensorflow/examples](https://github.com/tensorflow/examples).
```
git clone https://github.com/tensorflow/examples.git
```
## 3.1 iOS
1. Open the project as described in [README.md](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/ios).
2. Select `ImageClassification/ImageClassification/Model`, cmd ⌘ + click it, choose `Add Files to "ImageClassification"...`, and add **quant_graph.tflite** and **output_labels.txt**.
![image.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/425587/04212b56-0ad9-3939-93ad-cca1cac7e67b.png)
3. Edit `ImageClassification/ImageClassification/ModelDataHandler/ModelDataHandler.swift`:
4. Change "mobilenet_quant_v1_224" on line 37 to **"quant_graph"**
5. Change "labels" on line 38 to **"output_labels"**
6. Change the value of `inputWidth` on line 58 to **299**
7. Change the value of `inputHeight` on line 59 to **299**
```swift
enum MobileNet {
  static let modelInfo: FileInfo = (name: "quant_graph", extension: "tflite")
  static let labelsInfo: FileInfo = (name: "output_labels", extension: "txt")
}

// ... (omitted) ...

  // MARK: - Model Parameters

  let batchSize = 1
  let inputChannels = 3
  let inputWidth = 299
  let inputHeight = 299
```
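The 299 values must agree with the `--input_shape=1,299,299,3` used when converting. As far as I know, retrain.py's default TF Hub module is Inception v3, which expects 299×299 images, while the sample's bundled `mobilenet_quant_v1_224` uses 224×224. As I understand the sample, it scales each camera frame to `inputWidth` × `inputHeight` and copies the RGB bytes into the input tensor, so the byte counts have to line up.

`input_tensor_size` is a hypothetical helper that just does this arithmetic for an NHWC tensor.

```python
def input_tensor_size(height, width, channels=3, batch=1, bytes_per_value=1):
    """Bytes in one NHWC input tensor (1 byte per value for uint8, 4 for float32)."""
    return batch * height * width * channels * bytes_per_value
```

`input_tensor_size(299, 299)` is 268,203 bytes for the quantized model, compared with 150,528 for 224×224. The float model used on Android would need `input_tensor_size(299, 299, bytes_per_value=4)`, which is 1,072,812 bytes.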
## 3.2 Android
1. Open `\examples\lite\examples\image_classification\android` in Android Studio
2. Place **float_graph.tflite** and **output_labels.txt** in `\app\src\main\assets`
3. Edit `app\src\main\java\org\tensorflow\lite\examples\classification\tflite\ClassifierFloatMobileNet.java`:
4. Change "mobilenet_v1_1.0_224.tflite" on line 55 to **"float_graph.tflite"**
5. Change "labels.txt" on line 60 to **"output_labels.txt"**
```java
  @Override
  protected String getModelPath() {
    // float_graph.tflite is the model copied into app/src/main/assets in step 2.
    return "float_graph.tflite";
  }

  @Override
  protected String getLabelPath() {
    // Label file produced by retrain.py, also placed in assets.
    return "output_labels.txt";
  }
```