Machine learning model inference web API server configuration [FastAPI implementation example included]

Purpose of this article

This article introduces a typical configuration for a machine learning inference web API. Most of it should be readable without detailed knowledge of either the web or machine learning (excluding the implementation examples). The configurations come from my experience building inference web APIs for several machine learning models at work, but they are my personal opinion, so if you know of something better, please let me know in the comments; I would be happy to hear it. For the implementation examples, the web framework is FastAPI, chosen for how easily it handles asynchronous processing and for the simplicity of the implementation.

Table of contents

  1. Machine learning inference web API configuration
  2. Implementation example

1. Machine learning inference web API configuration

In this article, I will introduce two patterns.

Note) I will explain the common part first. Machine learning knowledge is basically needed only for this common part, so if you are unfamiliar with either machine learning or the web, the work can be split between the common part and the parts described later, and you can skim whichever side is not yours.

Inference API (common part)

To get predictions from a trained model, you will generally build a machine learning inference API like the following. Even when developing only on a local PC or in a Jupyter Notebook, you probably end up creating this kind of API (pipeline).

I will omit the details, but for load balancing and model management it can make sense to carve out just the part that uses the machine learning model onto its own server in the cloud (reference: [GCP AI Platform Prediction](https://cloud.google.com/ai-platform/prediction/docs)). For a heavy model that has performance problems unless a GPU is also used for inference, the load cannot be handled by an ordinary web application server, so being able to isolate it gives you more flexibility. The same kind of configuration also applies when you use an external service that serves a trained model.

(Figure: online prediction vs. batch prediction API)

As the amount of data grows, measures such as replacing the preprocessing with a large-scale data processing engine like Google Cloud Dataflow will likely become necessary.

When turning an inference API developed on a local PC or in a Jupyter Notebook into a web API, there are mainly two patterns to consider, listed below. They differ in how they handle I/O data.

- 1.1. Online prediction (also called HTTP prediction)
- 1.2. Batch prediction

(The names follow GCP's AI Platform terminology. Reference: [Online prediction vs. batch prediction](https://cloud.google.com/ml-engine/docs/tensorflow/online-vs-batch-prediction?hl=ja))

1.1. Online prediction

(Figure: online prediction API)

This is a simple configuration that runs the ML function when an HTTP request arrives and immediately returns the output in the HTTP response. The weights are loaded only once, when the server starts. If the weights are fetched from cloud storage (Google Cloud Storage, etc.) at load time, swapping out the model becomes easier.
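
As a rough sketch of the "fetch weights from cloud storage at startup" idea, the snippet below (not part of the original article; the bucket and object names are placeholders) downloads a weight file from Google Cloud Storage so the model can load it once at boot:

# Sketch: download a weight file from Google Cloud Storage once, at server startup.
# Requires the google-cloud-storage package; bucket/blob names are hypothetical.
from google.cloud import storage

def fetch_weights(bucket_name: str = "my-model-bucket",
                  blob_name: str = "models/sentiment/latest.pkl",
                  local_path: str = "/tmp/model.pkl") -> str:
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(blob_name)
    blob.download_to_filename(local_path)  # write the weights to local disk
    return local_path

# weight_path = fetch_weights()  # then pass weight_path to the model's load()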

Advantages

- Simply moving an inference function that already works locally into the web framework is enough to make it work.
- The inference result comes back from a single API call, so the API is easy to use.
- Fast responses when both the model and the data are small.

Disadvantages

- Web APIs are often configured to time out after tens of seconds to a few minutes for load-balancing reasons, so if inference takes a long time, the request will fail. This pattern is therefore unsuitable for heavy models or for processing large amounts of data at once.

1.2. Batch prediction

(Figure: batch prediction API)

If a response cannot (or does not need to) be returned immediately, the inference result of the ML API is stored in some storage instead of being returned directly in the response, as shown below. The processing can be divided into the three stages listed below. (It is enough for stages 2 and 3 to be separate; the upload API can be merged into the ML API.)

  1. upload API: store the input data in storage (a database, cloud storage, etc.)
  2. ML API (asynchronous execution): get the data from storage, run the ML function, and save the result back to storage; the response is returned before processing completes
  3. download API: get the result from storage and return it

Each API can be loosely coupled, so the implementation of the upload and download APIs is quite flexible. They can be used in various ways, for example:

- Accumulate input data for a certain period and run inference on it all at once at the end of the day
- Run inference with a complicated model that would otherwise time out
- Cache inference results so the same input is not inferred twice

The upload and download APIs can also be implemented in languages other than Python, and the APIs can live on different servers as long as they can all read and write the same storage. The front end can even read and write storage directly without going through an API; especially when the input/output is an image, handling cloud storage directly gives a simpler flow, as sketched below.
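
As a rough sketch of letting the front end talk to cloud storage directly, one common approach with Google Cloud Storage is to have the API hand out a signed URL that the client can upload to. The snippet below is an illustration under assumptions (bucket and object names are placeholders), not part of the original article:

# Sketch: issue a V4 signed URL so the front end can PUT an image straight to GCS.
# Requires google-cloud-storage and credentials; all names here are hypothetical.
from datetime import timedelta
from google.cloud import storage

def make_upload_url(bucket_name: str = "my-input-bucket",
                    object_name: str = "inputs/sample.png") -> str:
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(object_name)
    return blob.generate_signed_url(
        version="v4",
        expiration=timedelta(minutes=15),
        method="PUT",                 # the front end uploads directly to this URL
        content_type="image/png",
    )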

Advantages

- No failures due to timeouts
- A high degree of freedom
- A training API can be implemented with the same configuration

Disadvantages

- Harder to use because it is more complex than online prediction
- Takes longer end to end than online prediction

2. Implementation example

Let's implement the online prediction and batch prediction APIs with FastAPI. Looking at the examples below, I hope you will feel that, once the inference pipeline is properly wrapped in functions locally, the hurdle to turning it into a web API is quite low.

What this article will not do

This article does not cover the following:

What is FastAPI

FastAPI is a Python micro web framework, similar in spirit to Flask. Its strengths include high performance, ease of writing, a design strongly geared toward production use, and modern features; asynchronous processing in particular is easy to handle.
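
As a minimal illustration (a toy app, separate from the API built below), a FastAPI application with an async endpoint fits in a few lines and can be served with uvicorn:

# Minimal FastAPI app. Assuming this file is named hello.py, run it with:
#   uvicorn hello:app --reload
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
async def health():
    # async endpoints are first-class citizens in FastAPI
    return {"status": "ok"}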

The rest of this article assumes basic knowledge of FastAPI. If you want to know the details, refer to the following as needed.

- Official documentation
- [FastAPI] Getting Started with Python's ASGI Web Framework FastAPI

Inference API (common part)

To keep things generic, we define a very rough mock. It does not do anything meaningful, but for convenience we will call it a natural language processing sentiment analysis task.

The required capabilities are as follows. (If only the model is carved out onto a separate server, loading and holding the model here is unnecessary.)

- Loading a model instance (from weights, or via joblib, pickle, etc.)
- Holding the loaded model
- An inference pipeline

This time we will use a model whose predict returns random emotions. To make the processing time realistic, it sleeps for 20 seconds when loading and for 10 seconds when predicting.

ml.py


from random import choice
from time import sleep

class MockMLAPI:
    def __init__(self):
        # model instance
        self.model = None

    def load(self, filepath=''):
        """
        when server is activated, load weight or use joblib or pickle for performance improvement.
        then, assign pretrained model instance to self.model.
        """
        sleep(20)
        pass

    def predict(self, x):
        """implement followings
        - Load data
        - Preprocess
        - Prediction using self.model
        - Post-process
        """
        sleep(10)
        preds = [choice(['happy', 'sad', 'angry']) for i in range(len(x))]
        out = [{'text': t.text, 'sentiment': s} for t, s in zip(x, preds)]
        return out

Request / response data format

Define the format of the request data. Let's support multiple inputs, as shown below.

{
  "data": [
    {"text": "hogehoge"},
    {"text": "fugafuga"}
  ]
}

The response data should be in a format that adds the inference result to the input and returns it.

{
  "prediction": [
    {"text": "hogehoge", "sentiment": "angry"},
    {"text": "fugafuga", "sentiment": "sad"}
  ]
}

So, define the schemas as follows.

schemas.py


from pydantic import BaseModel
from typing import List

# request
class Text(BaseModel):
    text: str

class Data(BaseModel):
    data: List[Text]

# response
class Output(Text):
    sentiment: str

class Pred(BaseModel):
    prediction: List[Output]
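
As a quick sanity check (not part of the original code), the sample request shown earlier parses straight into these models:

# Sketch: validate the sample payload against the schemas above (run in the same module)
payload = {"data": [{"text": "hogehoge"}, {"text": "fugafuga"}]}
req = Data(**payload)
print(req.data[0].text)  # -> hogehoge
print(req.dict())        # -> {'data': [{'text': 'hogehoge'}, {'text': 'fugafuga'}]}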

2.1. Online prediction

Implement a web API that performs online prediction using the common part described above. All you need is:

- Load the trained machine learning model when the server starts
- Receive data, run inference with the ML API, and return the result

Implementing it as follows completes a minimal API.

main.py


from fastapi import FastAPI
from ml_api import schemas
from ml_api.ml import MockMLAPI

app = FastAPI()
ml = MockMLAPI()
ml.load()  # load the weights or a model instance using joblib or pickle

@app.post('/prediction/online', response_model=schemas.Pred)
async def online_prediction(data: schemas.Data):
    preds = ml.predict(data.data)
    return {"prediction": preds}

Operation check

Check the operation locally. Post the sample input with curl, and you can confirm that the expected output is returned. The response takes about 10 seconds to come back, so you can see that almost all of the time is spent in predict.

$ curl -X POST "http://localhost:8000/prediction/online" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"data\":[{\"text\":\"hogehoge\"},{\"text\":\"fugafuga\"}]}" -w  "\nelapsed time: %{time_starttransfer} s\n"

{"prediction":[{"text":"hogehoge","sentiment":"angry"},{"text":"fugafuga","sentiment":"happy"}]}
elapsed time: 10.012029 s

2.2. Batch prediction

Implement a web API that performs batch prediction using the common part described above. As a reminder, the processing flow has three stages:

  1. upload API: store the input data in storage (a database, cloud storage, etc.)
  2. ML API (asynchronous execution): get the data from storage, run the ML function, and save the result back to storage; the response is returned before processing completes
  3. download API: get the result from storage and return it

Input / Output

Normally you would save the data to cloud storage or a database, but to keep this article simple we save it as CSV files in local storage. First, define functions for reading and writing. When the input data is saved, a file name is generated from a random string, and that random string is passed between the APIs to tie the batch prediction flow together. The implementation may look long, but there are really only three things going on:

- Reading and writing CSV
- Building file paths
- Generating random strings

io.py


import os
import csv
from random import choice
import string
from typing import List
from ml_api import schemas

storage = os.path.join(os.path.dirname(__file__), 'local_storage')

# make sure the local storage directories exist
os.makedirs(os.path.join(storage, 'inputs'), exist_ok=True)
os.makedirs(os.path.join(storage, 'outputs'), exist_ok=True)

def save_csv(data, filepath: str, fieldnames=None):
    with open(filepath, 'w') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)

        writer.writeheader()
        for row in data:
            writer.writerow(row)

def load_csv(filepath: str):
    with open(filepath, 'r') as f:
        reader = csv.DictReader(f)
        out = list(reader)
    return out

def save_inputs(data: schemas.Data, length=8):
    letters = string.ascii_lowercase
    filename = ''.join(choice(letters) for i in range(length)) + '.csv'
    filepath = os.path.join(storage, 'inputs', filename)
    save_csv(data=data.dict()['data'], filepath=filepath, fieldnames=['text'])
    return filename

def load_inputs(filename: str):
    filepath = os.path.join(storage, 'inputs', filename)
    texts = load_csv(filepath=filepath)
    texts = [schemas.Text(**f) for f in texts]
    return texts

def save_outputs(preds: List[str], filename):
    filepath = os.path.join(storage, 'outputs', filename)
    save_csv(data=preds, filepath=filepath, fieldnames=['text', 'sentiment'])
    return filename

def load_outputs(filename: str):
    filepath = os.path.join(storage, 'outputs', filename)
    return load_csv(filepath=filepath)

def check_outputs(filename: str):
    filepath = os.path.join(storage, 'outputs', filename)
    return os.path.exists(filepath)

Web API

Build three APIs: upload, inference, and download. Since batch inference does not need to return an immediate response, the model is loaded every time the API is called.

Here, FastAPI's BackgroundTasks is used to process the model inference asynchronously: the inference runs in the background, and the response is returned immediately without waiting for it to finish.

main.py


from fastapi import FastAPI
from fastapi import BackgroundTasks
from fastapi import HTTPException
from ml_api import schemas, io
from ml_api.ml import MockMLAPI

app = FastAPI()

@app.post('/upload')
async def upload(data: schemas.Data):
    filename = io.save_inputs(data)
    return {"filename": filename}

def batch_predict(filename: str):
    """batch predict method for background process"""
    ml = MockMLAPI()
    ml.load()
    data = io.load_inputs(filename)
    pred = ml.predict(data)
    io.save_outputs(pred, filename)
    print('finished prediction')

@app.get('/prediction/batch')
async def batch_prediction(filename: str, background_tasks: BackgroundTasks):
    if io.check_outputs(filename):
        raise HTTPException(status_code=404, detail="the result of prediction already exists")

    background_tasks.add_task(batch_predict, filename)
    return {}

@app.get('/download', response_model=schemas.Pred)
async def download(filename: str):
    if not io.check_outputs(filename):
        raise HTTPException(status_code=404, detail="the result of prediction does not exist")

    preds = io.load_outputs(filename)
    return {"prediction": preds}

Operation check

Check the operation in the same way as for online prediction. Post the sample input with curl, and you can confirm that the expected output is returned. Here I waited about 30 seconds (model loading plus prediction) before hitting the download API, but you can see that each individual response returns very quickly.

$ curl -X POST "http://localhost:8000/upload" -H "accept: application/json" -H "Content-Type: application/json" -d "{\"data\":[{\"text\":\"hogehoge\"},{\"text\":\"fugafuga\"}]}" -w  "\nelapsed time: %{time_starttransfer} s\n"
{"filename":"fdlelteb.csv"}
elapsed time: 0.010242 s

$ curl -X GET "http://localhost:8000/prediction/batch?filename=fdlelteb.csv" -w  "\nelapsed time: %{time_starttransfer} s\n"
{}
elapsed time: 0.007223 s

$ curl -X GET "http://localhost:8000/download?filename=fdlelteb.csv" -w  "\nelapsed time: %{time_starttransfer} s\n"   [12:58:27]
{"prediction":[{"text":"hogehoge","sentiment":"happy"},{"text":"fugafuga","sentiment":"sad"}]}
elapsed time: 0.008825 s
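
For programmatic use, the three calls above can be chained in a small client. The sketch below is a hypothetical example using the requests library: it uploads the data, triggers the batch prediction, then polls the download endpoint until the result exists:

# Sketch of a client for the batch API above (requires the requests package)
import time
import requests

BASE = "http://localhost:8000"
payload = {"data": [{"text": "hogehoge"}, {"text": "fugafuga"}]}

filename = requests.post(f"{BASE}/upload", json=payload).json()["filename"]
requests.get(f"{BASE}/prediction/batch", params={"filename": filename})  # kick off inference

# Poll until the background task has written the result (the API returns 404 until then)
while True:
    res = requests.get(f"{BASE}/download", params={"filename": filename})
    if res.status_code == 200:
        print(res.json())
        break
    time.sleep(5)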

In conclusion

We have introduced two typical configurations for machine learning inference web APIs: online prediction and batch prediction. They require a small twist compared to a typical web API configuration, but I have also shown implementation examples that build them simply with FastAPI. I hope you got the feeling that, once the inference pipeline is properly wrapped in functions locally, the hurdle to turning it into a web API is low. Machine learning itself gets endless attention, but I feel there is still relatively little information on topics like web API configuration (though there turned out to be a fair amount; I have added a collection of links). The configurations introduced in this article are also rough, so I would appreciate any comments on how to improve them!

Related Links

These are links I could not cover in this article.

- N things to keep in mind when putting a machine learning platform for inference into production
- Publishing design patterns for machine learning systems (Mercari Engineering Blog)
- Machine learning system configuration examples (reference)
