I usually use Flask, but an acquaintance told me "FastAPI is good!", so I decided to build a simple image recognition API with it. There weren't many Japanese articles covering FastAPI together with machine learning, so I wrote this up as an article rather than keeping it as a private memo!
In this article, after setting up the development environment, I will briefly walk through the API server and the front end.
All the code used this time is published on GitHub. **(The folder structure described below assumes the layout of the GitHub repository. How to download the sample model is also covered in README.md.)**
FastAPI is a Python web framework, in the same vein as Flask.
For a quick overview and a summary of how to use it, the following article is a good reference. (It was a great help for this post as well!)
https://qiita.com/bee2/items/75d9c0d7ba20e7a4a0e9
If you want to dig deeper, I recommend the official FastAPI tutorial!
https://fastapi.tiangolo.com/tutorial/
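To give a feel for the framework, here is the minimal app from the official tutorial:

from fastapi import FastAPI

app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}

Serving this with uvicorn (installed below) and opening / returns the JSON above.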
I didn't have much time, so I will build on a tensorflow.keras model!
Specifically, we will use ResNet50 pretrained on ImageNet as-is and infer which of its 1000 classes an input image belongs to.
(The model I really wanted to use wasn't ready in time because it was still training...)
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras?hl=ja
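For reference, here is a minimal sketch of how a file like resnet_imagenet.h5 can be produced (assuming TF 1.15's bundled Keras; the actual sample model is linked in README.md):

import tensorflow as tf

# Download ImageNet-pretrained ResNet50 and save it as an .h5 file
model = tf.keras.applications.resnet50.ResNet50(weights="imagenet")
model.save("./static/model/resnet_imagenet.h5")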
Mac OS X Mojave / Python 3.7.1 (Anaconda)
Install the required Python libraries.
$pip install tensorflow==1.15
$pip install fastapi
$pip install uvicorn
The app also needs to do the following, so install the supporting libraries as well:

- Render index.html
- Upload image files
- Load and resize images
$pip install jinja2
$pip install aiofiles
$pip install python-multipart
$pip install opencv-python
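Once the API server code below is saved (here assumed to live in main.py), the development server can be started with uvicorn:

$uvicorn main:app --reload

The --reload flag restarts the server whenever the code changes, which is convenient during development.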
The implementation of the API server is as follows.
# -*- coding: utf-8 -*-
import io
from typing import List

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import decode_predictions
from fastapi import FastAPI, Request, File, UploadFile
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# Prepare the image recognition model
global model, graph
graph = tf.get_default_graph()
model = tf.keras.models.load_model("./static/model/resnet_imagenet.h5")

# Prepare FastAPI
app = FastAPI()

# Needed so index.html can load static/js/post.js
app.mount("/static", StaticFiles(directory="static"), name="static")

# Needed to render index.html stored under templates
templates = Jinja2Templates(directory="templates")


def read_image(bin_data, size=(224, 224)):
    """Load an image

    Arguments:
        bin_data {bytes} -- binary image data

    Keyword Arguments:
        size {tuple} -- target size after resizing (default: {(224, 224)})

    Returns:
        numpy.array -- the decoded image
    """
    file_bytes = np.asarray(bytearray(bin_data.read()), dtype=np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    return img


@app.post("/api/image_recognition")
async def image_recognition(files: List[UploadFile] = File(...)):
    """Image recognition API

    Keyword Arguments:
        files {List[UploadFile]} -- uploaded file information (default: {File(...)})

    Returns:
        dict -- inference result
    """
    bin_data = io.BytesIO(files[0].file.read())
    img = read_image(bin_data)
    with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]
    return {"response": result_label}


@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/api/image_recognition")
async def image_recognition(files: List[UploadFile] = File(...)):
    """Image recognition API

    Keyword Arguments:
        files {List[UploadFile]} -- uploaded file information (default: {File(...)})

    Returns:
        dict -- inference result
    """
    bin_data = io.BytesIO(files[0].file.read())
    img = read_image(bin_data)
    with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]
    return {"response": result_label}
Here, FastAPI's UploadFile is used to receive the POSTed image.
bin_data = io.BytesIO(files[0].file.read())
Since only one file is POSTed, we take files[0]. The front end builds the upload from base64-encoded canvas data (converted to a Blob before sending), so on the API side it is read into a bytes buffer.
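As a side note, the endpoint can also be exercised without the front end, for example with curl (a sketch; cat.jpg is a placeholder file name):

$curl -X POST -F "files=@cat.jpg" http://localhost:8000/api/image_recognition

The form field name files has to match the parameter name in the endpoint signature.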
def read_image(bin_data, size=(224, 224)):
    """Load an image

    Arguments:
        bin_data {bytes} -- binary image data

    Keyword Arguments:
        size {tuple} -- target size after resizing (default: {(224, 224)})

    Returns:
        numpy.array -- the decoded image
    """
    file_bytes = np.asarray(bytearray(bin_data.read()), dtype=np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    return img
With OpenCV's help, the byte array is decoded into a uint8 image. Since OpenCV's default channel order is BGR, the image is converted to RGB and then resized.
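As a quick sanity check, read_image also accepts an opened file object, since it only calls .read() on its argument (a sketch; ./sample.jpg is a placeholder):

# Minimal check of read_image (assumes a local ./sample.jpg exists)
with open("./sample.jpg", "rb") as f:
    img = read_image(f)
print(img.shape)  # expected: (224, 224, 3)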
global model, graph
graph = tf.get_default_graph()
model = tf.keras.models.load_model("./static/model/resnet_imagenet.h5")

...

with graph.as_default():
    pred = model.predict(np.expand_dims(img, axis=0))
    result_label = decode_predictions(pred, top=1)[0][0][1]
resnet_imagenet.h5 was created in advance and is loaded at the top of the file. For inference, with graph.as_default() pins this thread's context to the TensorFlow graph stored globally, and the predict function performs the actual prediction.
Since this uses ResNet50 from tf.keras, decode_predictions converts the output of predict into a label to obtain the inference result.
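To make the [0][0][1] indexing concrete: decode_predictions returns one list per input image, where each entry is a (class_id, label, score) tuple. Roughly (the values below are made up for illustration):

# Illustrative shape of decode_predictions(pred, top=1):
# [[("n02123045", "tabby", 0.81)]]
#  [0] -> results for the first image, [0] -> top entry, [1] -> the label string

One caveat: tf.keras also provides resnet50.preprocess_input, and applying it to the input array before predict is generally recommended with ImageNet weights; the code above feeds raw RGB values.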
Other models, including self-made ones, should work with this implementation as well: save the .h5 file somewhere in the project directory and load it with load_model.
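For example, a self-made classifier could be dropped in like this (a sketch; the path and label list are hypothetical, and note that decode_predictions only knows the ImageNet classes, so a custom model needs its own label mapping):

# Hypothetical: swapping in a self-made model
model = tf.keras.models.load_model("./static/model/my_model.h5")  # assumed path
labels = ["cat", "dog", "bird"]  # your own class names, in training order

pred = model.predict(np.expand_dims(img, axis=0))
result_label = labels[int(np.argmax(pred[0]))]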
I used this as a reference. (Thank you!)
https://qiita.com/katsunory/items/9bf9ee49ee5c08bf2b3d
<html>
  <head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>FastAPI image recognition test</title>
    <script src="//code.jquery.com/jquery-2.2.3.min.js"></script>
    <script src="/static/js/post.js"></script>
  </head>
  <body>
    <!-- File selection button -->
    <div style="width: 500px">
      <form enctype="multipart/form-data" method="post">
        <input type="file" name="userfile" accept="image/*">
      </form>
    </div>

    <!-- Image display area -->
    <canvas id="canvas" width="0" height="0"></canvas>

    <!-- Upload start button -->
    <button class="btn btn-primary" id="post">Post</button>
    <br>
    <h2 id="result"></h2>
  </body>
</html>
// Resize the image and display it in HTML
$(function () {
  var file = null;
  var blob = null;
  const RESIZED_WIDTH = 300;
  const RESIZED_HEIGHT = 300;

  $("input[type=file]").change(function () {
    file = $(this).prop("files")[0];

    // Only accept JPEG or PNG files
    if (file.type != "image/jpeg" && file.type != "image/png") {
      file = null;
      blob = null;
      return;
    }

    var result = document.getElementById("result");
    result.innerHTML = "";

    // Resize the image
    var image = new Image();
    var reader = new FileReader();
    reader.onload = function (e) {
      image.onload = function () {
        var width, height;

        // Fit the longer side to the target size
        if (image.width > image.height) {
          var ratio = image.height / image.width;
          width = RESIZED_WIDTH;
          height = RESIZED_WIDTH * ratio;
        } else {
          var ratio = image.width / image.height;
          width = RESIZED_HEIGHT * ratio;
          height = RESIZED_HEIGHT;
        }

        var canvas = $("#canvas").attr("width", width).attr("height", height);
        var ctx = canvas[0].getContext("2d");
        ctx.clearRect(0, 0, width, height);
        ctx.drawImage(image, 0, 0, image.width, image.height, 0, 0, width, height);

        // Get base64 image data from the canvas and create a Blob for the POST
        var base64 = canvas.get(0).toDataURL("image/jpeg");
        var barr, bin, i, len;
        bin = atob(base64.split("base64,")[1]);
        len = bin.length;
        barr = new Uint8Array(len);
        i = 0;
        while (i < len) {
          barr[i] = bin.charCodeAt(i);
          i++;
        }
        blob = new Blob([barr], { type: "image/jpeg" });
        console.log(blob);
      };
      image.src = e.target.result;
    };
    reader.readAsDataURL(file);
  });

  // When the upload start button is clicked
  $("#post").click(function () {
    if (!file || !blob) {
      return;
    }

    var fd = new FormData();
    fd.append("files", blob);

    // POST to the API
    $.ajax({
      url: "/api/image_recognition",
      type: "POST",
      dataType: "json",
      data: fd,
      processData: false,
      contentType: false,
    })
      .done(function (data, textStatus, jqXHR) {
        // On success, show the inference result
        console.log(data);
        var result = document.getElementById("result");
        result.innerHTML = "This image is... \"" + data["response"] + "\"!";
      })
      .fail(function (jqXHR, textStatus, errorThrown) {
        // On failure, show an error message
        var result = document.getElementById("result");
        result.innerHTML = "Communication with the server failed...";
      });
  });
});
The image is POSTed to the image recognition API with Ajax, and the result is displayed.
And with that, it works like this!
(I wanted to make the front end a little more stylish...)
I built an image recognition API as a way to study FastAPI. I don't think this implementation is best practice, but I'm glad I was able to make something that works.
I don't know yet which framework I'll settle on, but FastAPI was easy enough to use that I'm considering switching over from Flask.
**Last but not least, thank you to everyone whose work helped me!**