How to use a trained TensorFlow 2.0 model with Kotlin / Java

Introduction

In this article, I will show how to use a model trained with TensorFlow 2.0 from Kotlin. The sample code is Kotlin only, but the same approach should work in Java.

The method introduced here uses KerasModelImport from the deeplearning4j library. TensorFlow has its own Java API, so a separate library such as deeplearning4j should not really be necessary, but a TensorFlow 2.0 build for Java has not been distributed yet, so deeplearning4j is used as a stopgap. (* If you build it yourself, you may be able to use a Java API compatible with TensorFlow 2.0.)

In other words:

I want to run deep learning inference in Kotlin / Java! But I don't want to write training code with deeplearning4j!! And I can't wait for the TensorFlow 2.0 build for Java to be distributed!!!

Think of this article as a stopgap for people who feel that way.

Library version

Both TensorFlow and deeplearning4j change drastically from version to version, so the behavior described here is likely to differ with other versions. I used the following:

tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7

Training code (Python, TensorFlow)

As a sample, the Python code for training on MNIST with TensorFlow 2.0 is shown below.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model

# Use float64; with other dtypes the model may not work properly on deeplearning4j.
tf.keras.backend.set_floatx('float64')

# Build with the Functional API; Sequential and subclassed models fail to import with deeplearning4j.
def make_functional_model(data_size, target_size):
    inputs = Input(shape=(data_size,))
    fc1 = Dense(512, activation='relu')
    fc2 = Dense(target_size, activation='softmax')
    outputs = fc2(fc1(inputs))
    return Model(inputs=inputs, outputs=outputs)

def train(dataset, model, criterion, optimizer):
    for data, target in dataset:
        with tf.GradientTape() as tape:
            output = model(data)
            loss = criterion(target, output)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

def main():

    # Prepare MNIST: scale pixels to [0, 1] and flatten each image to a vector
    mnist = tf.keras.datasets.mnist
    (train_data, train_target), _ = mnist.load_data()
    train_data = train_data / 255.0
    train_data = np.reshape(train_data, (train_data.shape[0], -1))
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)

    # Model preparation
    data_size = train_data.shape[1]
    target_size = 10
    model = make_functional_model(data_size, target_size)

    # Training
    criterion = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(5):
        train(train_dataset, model, criterion, optimizer)
        
    # Save in HDF5 format; the SavedModel format fails to import on deeplearning4j
    model.save('checkpoint.h5')

if __name__ == '__main__':
    main()

That's all for the training code. Note that the import into deeplearning4j worked only when a Functional model was saved in HDF5 format. Models defined with the Sequential API or by subclassing, and models saved in the SavedModel format, could not be imported.
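Since only the HDF5 format imported successfully, it can help to confirm what was actually written to disk before handing the path to deeplearning4j. The following is a minimal sketch, not part of the original workflow: the helper name looks_like_hdf5 is hypothetical, and it assumes the file has no HDF5 user block, so the 8-byte HDF5 signature sits at offset 0. A SavedModel is a directory, so the check returns False for it.

```python
from pathlib import Path

# The 8-byte signature that begins an HDF5 file (assuming no user block).
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path):
    """Return True if `path` is a regular file starting with the HDF5 signature.

    Hypothetical helper for illustration; it only inspects the first bytes
    and does not validate the model stored inside the file.
    """
    p = Path(path)
    if not p.is_file():
        # A SavedModel export is a directory, so it ends up here.
        return False
    with p.open("rb") as f:
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
```

For example, looks_like_hdf5('checkpoint.h5') could be called right after model.save('checkpoint.h5') in the training script.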

Inference code (Kotlin, deeplearning4j)

Next is the inference code, written in Kotlin using deeplearning4j.

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport

fun main() {

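    // MNIST test split: batch size 1, train = false, seed 123
    val mnist = MnistDataSetIterator(1, false, 123)
    // A Functional model is imported as a ComputationGraph
    val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
    // An importer for Sequential models also exists, but that import fails
    //val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")

    mnist.forEach {
        val data = it.features
        // Class probabilities for the single image in this batch
        val output = model.output(data)[0].toFloatVector()
        // Predicted digit = index of the largest probability
        val pred = output.indexOf(output.max()!!)
        println("predicted: $pred")
    }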
}

Just import the model and run inference with the output function; it is very easy to use. There is also an import function for Sequential models, but as mentioned above, that import fails.
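To make it clearer what the output call computes for this particular network, here is a small pure-Python sketch of the same forward pass (Dense + ReLU, then Dense + softmax) followed by the argmax step. It is only an illustration: the weights below are hypothetical toy values, not the ones stored in checkpoint.h5, and the function names are my own. The weight layout follows the Keras Dense kernel shape (inputs, units).

```python
import math


def dense(x, weights, bias):
    """Fully connected layer: y[j] = sum_i x[i] * weights[i][j] + bias[j]."""
    return [
        sum(x_i * row[j] for x_i, row in zip(x, weights)) + bias[j]
        for j in range(len(bias))
    ]


def relu(v):
    return [max(0.0, a) for a in v]


def softmax(v):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


def predict(x, w1, b1, w2, b2):
    """Forward pass of the two-layer network, then argmax over the probabilities."""
    probs = softmax(dense(relu(dense(x, w1, b1)), w2, b2))
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs


# Hypothetical toy weights: 3 inputs, 2 hidden units, 2 classes.
W1 = [[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]]
B1 = [0.0, 0.0]
W2 = [[2.0, 0.0], [0.0, 2.0]]
B2 = [0.0, 0.0]

label, probs = predict([1.0, 0.0, 0.0], W1, B1, W2, B2)
print(label, probs)
```

The argmax in predict plays the same role as output.indexOf(output.max()) in the Kotlin code: the predicted class is the index of the largest probability.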

Conclusion

In this article, I introduced how to use a trained TensorFlow 2.0 model with Kotlin / Java. I hope the TensorFlow Java API becomes available in a stable form soon; until then, I will make do with this approach.

And let's all do Deep Learning with Kotlin! (This is what I want to say the most)
