I tried to make a simple image recognition API with FastAPI and TensorFlow

Introduction

I usually use Flask, but an acquaintance told me "FastAPI is good!", so I decided to build a simple image recognition API with it. I couldn't find many Japanese articles on FastAPI and machine learning, so I wrote this up as an article instead of keeping it as a personal memo!

In this article, after setting up the development environment, I briefly explain the API server and the front end.

All the code used here is published on GitHub. (The folder structure in the implementation below assumes the GitHub repository layout. Instructions for downloading the sample model are in README.md.)

What is FastAPI?

It is a Python web framework, similar to Flask.

For a brief overview and a summary of how to use it, please refer to the following article. (It was a great help for this article too, thank you!)

https://qiita.com/bee2/items/75d9c0d7ba20e7a4a0e9

If you want to learn more, I recommend the official FastAPI tutorial!

https://fastapi.tiangolo.com/tutorial/

About image recognition

I was short on time, so I built it using a pretrained model from tensorflow.keras!

Specifically, I use ResNet50 pretrained on ImageNet as-is and infer which of the 1,000 ImageNet classes the input image belongs to.

(The model I actually wanted to use is still being trained, so it wasn't ready in time...)

https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras?hl=ja

Development environment

macOS Mojave
Python 3.7.1 (Anaconda)

Environment setup

Install the required Python libraries.

$ pip install tensorflow==1.15
$ pip install fastapi
$ pip install uvicorn

The app also needs to do the following, so install the libraries for these as well:
- Render index.html
- Upload image files
- Load and resize images

$ pip install jinja2
$ pip install aiofiles
$ pip install python-multipart
$ pip install opencv-python

API server

The implementation of the API server is as follows.

# -*- coding: utf-8 -*-
import io
from typing import List

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import decode_predictions
from fastapi import FastAPI, Request, File, UploadFile
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# Prepare the image recognition model
global model, graph
graph = tf.get_default_graph()
model = tf.keras.models.load_model("./static/model/resnet_imagenet.h5")

# Prepare FastAPI
app = FastAPI()

# Required so that index.html can load static/js/post.js
app.mount("/static", StaticFiles(directory="static"), name="static")

# Required to render index.html stored under templates/
templates = Jinja2Templates(directory="templates")


def read_image(bin_data, size=(224, 224)):
    """Load image

    Arguments:
        bin_data {io.BytesIO} -- Image binary data

    Keyword Arguments:
        size {tuple} -- Image size to resize to (default: {(224, 224)})

    Returns:
        numpy.ndarray -- RGB image of the given size
    """
    file_bytes = np.asarray(bytearray(bin_data.read()), dtype=np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    return img


@app.post("/api/image_recognition")
async def image_recognition(files: List[UploadFile] = File(...)):
    """Image recognition API

    Keyword Arguments:
        files {List[UploadFile]} -- Uploaded file information (default: {File(...)})

    Returns:
        dict -- Inference result
    """
    bin_data = io.BytesIO(files[0].file.read())
    img = read_image(bin_data)
    with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]
        return {"response": result_label}


@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

Receiving data from the front end

@app.post("/api/image_recognition")
async def image_recognition(files: List[UploadFile] = File(...)):
    """Image recognition API

    Keyword Arguments:
        files {List[UploadFile]} -- Uploaded file information (default: {File(...)})

    Returns:
        dict -- Inference result
    """
    bin_data = io.BytesIO(files[0].file.read())
    img = read_image(bin_data)
    with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]
        return {"response": result_label}

Here, FastAPI's UploadFile is used to receive the POSTed image.

bin_data = io.BytesIO(files[0].file.read())

Since only one file is POSTed, only files[0] is used. The front end decodes the canvas's base64 data into a binary Blob before sending it as multipart/form-data, so the API receives raw bytes, which are wrapped in an io.BytesIO stream.
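To make the data flow concrete, here is a small stdlib-only sketch of the conversion the front end performs before the POST: strip the `data:image/jpeg;base64,` prefix from the data URL and base64-decode the rest into raw bytes. The helper name `data_url_to_bytes` and the stand-in bytes are hypothetical; the point is that the server receives these decoded bytes, not base64 text.

```python
import base64
import io


def data_url_to_bytes(data_url: str) -> bytes:
    """Hypothetical helper mirroring post.js: drop everything up to 'base64,'
    and decode the rest (post.js does this with atob + Uint8Array)."""
    _, sep, payload = data_url.partition("base64,")
    if not sep:
        raise ValueError("not a base64 data URL")
    return base64.b64decode(payload)


# Stand-in image data: a real JPEG always starts with the SOI marker FF D8
raw = b"\xff\xd8\xff\xe0 rest of a jpeg"
data_url = "data:image/jpeg;base64," + base64.b64encode(raw).decode("ascii")

decoded = data_url_to_bytes(data_url)   # what actually goes into the Blob / POST body
bin_data = io.BytesIO(decoded)          # what the API wraps with io.BytesIO(...)
print(bin_data.read(2) == b"\xff\xd8")  # True
```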

Converting the data to an image

def read_image(bin_data, size=(224, 224)):
    """Load image

    Arguments:
        bin_data {io.BytesIO} -- Image binary data

    Keyword Arguments:
        size {tuple} -- Image size to resize to (default: {(224, 224)})

    Returns:
        numpy.ndarray -- RGB image of the given size
    """
    file_bytes = np.asarray(bytearray(bin_data.read()), dtype=np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    return img

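OpenCV decodes the byte array into a uint8 image array. Because OpenCV's default channel order is BGR, the image is converted to RGB and then resized to 224×224. Note that Keras documents tf.keras.applications.resnet50.preprocess_input as the expected input preprocessing for ResNet50; this code does not call it, so unless the saved .h5 already includes that step, applying it before predict may improve accuracy.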
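For readers new to OpenCV, here is a dependency-free sketch of the two steps. An image is modeled as a list of rows of (B, G, R) tuples; converting to RGB just reverses each pixel's channel order. The resize uses nearest-neighbor sampling for simplicity, which is not what cv2.resize does by default (bilinear interpolation), so treat this as an illustration of the idea, not a drop-in replacement. The function names are hypothetical.

```python
from typing import List, Tuple

Pixel = Tuple[int, int, int]
Image = List[List[Pixel]]


def bgr_to_rgb(img: Image) -> Image:
    """Reverse the channel order of every pixel: (B, G, R) -> (R, G, B)."""
    return [[(p[2], p[1], p[0]) for p in row] for row in img]


def resize_nearest(img: Image, size: Tuple[int, int]) -> Image:
    """Nearest-neighbor resize. Like cv2.resize, `size` is (width, height)."""
    out_w, out_h = size
    in_h, in_w = len(img), len(img[0])
    return [
        [img[y * in_h // out_h][x * in_w // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


# A 2x2 BGR image: pure blue, green, red, and white pixels
bgr = [[(255, 0, 0), (0, 255, 0)],
       [(0, 0, 255), (255, 255, 255)]]
rgb = bgr_to_rgb(bgr)
print(rgb[0][0])              # (0, 0, 255): blue is now in the last channel
big = resize_nearest(rgb, (4, 4))
print(len(big), len(big[0]))  # 4 4
```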

Inference

global model, graph
graph = tf.get_default_graph()
model = tf.keras.models.load_model("./static/model/resnet_imagenet.h5")

...

with graph.as_default():
    pred = model.predict(np.expand_dims(img, axis=0))
    result_label = decode_predictions(pred, top=1)[0][0][1]

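I created resnet_imagenet.h5 in advance and load it at the top of the file. For inference, with graph.as_default() makes the TensorFlow graph stored in the global variable the default graph for the current thread, and model.predict runs inside that context. In TensorFlow 1.x the default graph is tracked per thread, so this pins the graph the model was loaded into, whichever thread handles the request.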
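The per-thread behavior is easy to miss, so here is a toy model of it in plain Python built on threading.local. It is not TensorFlow's implementation; `_state`, `get_default`, and `as_default` are hypothetical stand-ins. A thread that never enters the context falls back to the process-wide default, which is only the right graph if the model happened to be built there; entering the context removes that dependence.

```python
import threading
from contextlib import contextmanager

# Hypothetical stand-ins: each thread gets its own attributes on _state
_state = threading.local()
GLOBAL_DEFAULT = "global-default-graph"


def get_default():
    """Innermost graph pushed on *this* thread, else the process-wide fallback."""
    stack = getattr(_state, "stack", [])
    return stack[-1] if stack else GLOBAL_DEFAULT


@contextmanager
def as_default(graph):
    """Push `graph` onto the current thread's stack for the duration of the block."""
    if not hasattr(_state, "stack"):
        _state.stack = []
    _state.stack.append(graph)
    try:
        yield graph
    finally:
        _state.stack.pop()


seen = {}


def handler(name, graph=None):
    """Simulated request handler running on a worker thread."""
    if graph is None:
        seen[name] = get_default()          # no context entered
    else:
        with as_default(graph):             # like `with graph.as_default():`
            seen[name] = get_default()


with as_default("model-graph"):
    seen["main"] = get_default()
    workers = [
        threading.Thread(target=handler, args=("plain",)),
        threading.Thread(target=handler, args=("wrapped", "model-graph")),
    ]
    for t in workers:
        t.start()
    for t in workers:
        t.join()

print(seen["main"], seen["plain"], seen["wrapped"])
# model-graph global-default-graph model-graph
```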

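Because the model is ResNet50 from tf.keras, decode_predictions converts the raw output of predict into an ImageNet class label. The index [0][0][1] selects the first image in the batch, its top-scoring prediction, and that prediction's human-readable class name.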
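To see what that indexing pulls out, here is a stdlib-only sketch that mimics the output shape of decode_predictions: one list per image, each holding (WordNet ID, class name, score) tuples sorted by score. `decode_top_k` and the three-entry `CLASS_INDEX` are hypothetical; Keras uses the full 1,000-class ImageNet index.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical 3-class subset in the same shape as the ImageNet index:
# output position -> (WordNet ID, human-readable class name)
CLASS_INDEX: Dict[int, Tuple[str, str]] = {
    0: ("n02085620", "Chihuahua"),
    1: ("n02099601", "golden_retriever"),
    2: ("n02123045", "tabby"),
}


def decode_top_k(preds: Sequence[Sequence[float]],
                 top: int = 5) -> List[List[Tuple[str, str, float]]]:
    """Turn a batch of score vectors into per-image lists of
    (wnid, class_name, score) tuples, highest score first."""
    results = []
    for scores in preds:
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        results.append([(*CLASS_INDEX[i], scores[i]) for i in order[:top]])
    return results


pred = [[0.1, 0.7, 0.2]]         # a batch of one image, as returned by predict
decoded = decode_top_k(pred, top=1)
print(decoded)                   # [[('n02099601', 'golden_retriever', 0.7)]]
print(decoded[0][0][1])          # golden_retriever
```

For a self-made model, the same pattern applies with your own list of labels in place of the ImageNet index.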

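Other models, including self-made ones, should work the same way: save the model as an .h5 file somewhere in the project directory and load it with load_model. (decode_predictions is specific to the 1,000 ImageNet classes, so a self-made model needs its own mapping from output index to label.)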

Front-end implementation

I used this as a reference. (Thank you!)

https://qiita.com/katsunory/items/9bf9ee49ee5c08bf2b3d

<html>
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8">
    <title>Fastapi image recognition test</title>
    <script src="//code.jquery.com/jquery-2.2.3.min.js"></script>
    <script src="/static/js/post.js"></script>
</head>

<body>

<!--File selection button-->
<div style="width: 500px">
  <form enctype="multipart/form-data" method="post">
    <input type="file" name="userfile" accept="image/*">
  </form>
</div>

<!--Image display area-->
<canvas id="canvas" width="0" height="0"></canvas>

<!--Upload start button-->
<button class="btn btn-primary" id="post">Post</button>
<br>
<h2 id="result"></h2>
</body>
</html>

//Resize the image and display it in HTML
$(function () {
  var file = null;
  var blob = null;
  const RESIZED_WIDTH = 300;
  const RESIZED_HEIGHT = 300;

  $("input[type=file]").change(function () {
    file = $(this).prop("files")[0];

    //File check
    if (!file || (file.type != "image/jpeg" && file.type != "image/png")) {
      file = null;
      blob = null;
      return;
    }

    var result = document.getElementById("result");
    result.innerHTML = "";

    //Resize the image
    var image = new Image();
    var reader = new FileReader();
    reader.onload = function (e) {
      image.onload = function () {
        var width, height;

        //Resize to fit the longer one
        if (image.width > image.height) {
          var ratio = image.height / image.width;
          width = RESIZED_WIDTH;
          height = RESIZED_WIDTH * ratio;
        } else {
          var ratio = image.width / image.height;
          width = RESIZED_HEIGHT * ratio;
          height = RESIZED_HEIGHT;
        }

        var canvas = $("#canvas").attr("width", width).attr("height", height);
        var ctx = canvas[0].getContext("2d");
        ctx.clearRect(0, 0, width, height);
        ctx.drawImage(
          image,
          0,
          0,
          image.width,
          image.height,
          0,
          0,
          width,
          height
        );

        //Get base64 image data from canvas and create Blob for POST
        var base64 = canvas.get(0).toDataURL("image/jpeg");
        var barr, bin, i, len;
        bin = atob(base64.split("base64,")[1]);
        len = bin.length;
        barr = new Uint8Array(len);
        i = 0;
        while (i < len) {
          barr[i] = bin.charCodeAt(i);
          i++;
        }
        blob = new Blob([barr], { type: "image/jpeg" });
        console.log(blob);
      };
      image.src = e.target.result;
    };
    reader.readAsDataURL(file);
  });

  //When the upload start button is clicked
  $("#post").click(function () {
    if (!file || !blob) {
      return;
    }

    var fd = new FormData();
    fd.append("files", blob);

    //POST to API
    $.ajax({
      url: "/api/image_recognition",
      type: "POST",
      dataType: "json",
      data: fd,
      processData: false,
      contentType: false,
    })
      .done(function (data, textStatus, jqXHR) {
        // On success, show the result (dataType "json" means data is already parsed)
        console.log(data);
        var result = document.getElementById("result");
        result.innerHTML = "This image is... \"" + data["response"] + "\", isn't it!";
      })
      .fail(function (jqXHR, textStatus, errorThrown) {
        // On failure, show an error message
        var result = document.getElementById("result");
        result.innerHTML = "Communication with the server failed...";
      });
  });
});

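When a file is chosen, post.js draws a resized copy on the canvas and converts it into a JPEG Blob; clicking Post sends that Blob to the image recognition API via Ajax and displays the returned label.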
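The resize logic in post.js pins the longer side of the image to 300 px and scales the other side by the original aspect ratio. Here is the same computation as a small Python function for clarity; the name `fit_within` is hypothetical.

```python
from typing import Tuple


def fit_within(width: float, height: float,
               max_w: float = 300, max_h: float = 300) -> Tuple[float, float]:
    """Mirror of post.js: the longer side becomes the limit and the
    other side keeps the original aspect ratio."""
    if width > height:                        # landscape: width becomes max_w
        return max_w, max_w * (height / width)
    # portrait or square: height becomes max_h
    return max_h * (width / height), max_h


print(fit_within(1200, 600))  # (300, 150.0)
print(fit_within(400, 800))   # (150.0, 300)
```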

Checking that it works

Here is the result: it works!

[Screenshot: demo.png]

(I wish I had made the front end a bit more stylish...)
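To check the API without the browser, you can also POST an image from Python. This is a standard-library sketch under some assumptions: the server listens at uvicorn's default address http://127.0.0.1:8000, and `build_multipart` and `recognize` are hypothetical helpers. It sends one multipart part named `files`, the field name the endpoint reads (matching `fd.append("files", blob)` in post.js).

```python
import json
import os
import urllib.request
import uuid
from typing import Tuple


def build_multipart(field: str, filename: str, data: bytes,
                    content_type: str = "image/jpeg") -> Tuple[bytes, str]:
    """Encode a single file as a multipart/form-data body.

    Returns the body and the matching Content-Type header value.
    """
    boundary = uuid.uuid4().hex  # random, so very unlikely to appear in the data
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n"
        "\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def recognize(path: str,
              url: str = "http://127.0.0.1:8000/api/image_recognition") -> str:
    """POST one image to the API and return the label from {"response": ...}."""
    with open(path, "rb") as f:
        body, ctype = build_multipart("files", os.path.basename(path), f.read())
    req = urllib.request.Request(url, data=body, method="POST",
                                 headers={"Content-Type": ctype})
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)["response"]
```

With the server running, `recognize("cat.jpg")` (a placeholder file name) should return the same label the web page displays.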

Conclusion

I made an image recognition API to study FastAPI. I don't think this implementation is best practice, but I'm glad I managed to build something that works.

I haven't decided which framework I'll use going forward, but FastAPI was fairly easy to use, and I'm thinking I might switch from Flask.

Last but not least, thank you to everyone who helped me!
