[Java] How to use a model trained with TensorFlow 2.0 from Kotlin/Java



In this article, I will show you how to use a model trained with TensorFlow 2.0 from Kotlin. The sample code is Kotlin only, but I expect it to work the same way in Java.

The method introduced here uses KerasModelImport from a library called deeplearning4j. Since TensorFlow has a Java API, there should be no need to use a minor library ~~such as deeplearning4j~~, but a Java build of TensorFlow 2.0 has not been distributed yet, so I am using it as a provisional measure. (*You may be able to use a Java API compatible with TensorFlow 2.0 if you build it yourself.)

In other words, this article is meant as a stopgap for people who think:

I want to run deep learning inference in Kotlin/Java! But I don't want to write training code in deeplearning4j!! I can't wait for the Java build of TensorFlow 2.0!!!

Library versions

Both TensorFlow and deeplearning4j change drastically between versions, so the behavior may well be different if you use other versions. The versions used here are:

TensorFlow (Python): 2.1.0
deeplearning4j (Kotlin/Java): 1.0.0-beta7
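The article does not show how the dependencies were declared. As an assumption, if you build with Gradle and the Kotlin DSL, pinning the versions above could look like the sketch below. The artifact names are the standard deeplearning4j/ND4J ones; the CPU backend (nd4j-native-platform) is an arbitrary choice for this sketch, not necessarily the author's setup.

```kotlin
// build.gradle.kts (sketch): keep every deeplearning4j/ND4J artifact on the same version
repositories {
    mavenCentral()
}

dependencies {
    val dl4jVersion = "1.0.0-beta7"
    // Core library; assumed to bring in dataset iterators such as MnistDataSetIterator
    implementation("org.deeplearning4j:deeplearning4j-core:$dl4jVersion")
    // Keras model import (KerasModelImport)
    implementation("org.deeplearning4j:deeplearning4j-modelimport:$dl4jVersion")
    // ND4J CPU backend; a GPU backend would be a different artifact
    implementation("org.nd4j:nd4j-native-platform:$dl4jVersion")
}
```

Mixing versions across these artifacts is a common source of runtime errors, which is why the sketch uses a single version variable.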

Training code (Python, TensorFlow)

As a sample, the Python code for training on MNIST with TensorFlow 2.0 is shown below.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model

# deeplearning4j may not work properly unless the model uses float64
tf.keras.backend.set_floatx('float64')

# Sequential and subclassed models fail to import in deeplearning4j,
# so the model is built with the functional API
def make_functional_model(data_size, target_size):
    inputs = Input(shape=(data_size,))
    fc1 = Dense(512, activation='relu')
    fc2 = Dense(target_size, activation='softmax')
    outputs = fc2(fc1(inputs))
    return Model(inputs=inputs, outputs=outputs)

def train(dataset, model, criterion, optimizer):
    # One epoch of a custom training loop
    for data, target in dataset:
        with tf.GradientTape() as tape:
            output = model(data)
            loss = criterion(target, output)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

def main():
    # Preparing MNIST: scale pixels to [0, 1] and flatten each 28x28 image to 784 values
    mnist = tf.keras.datasets.mnist
    (train_data, train_target), _ = mnist.load_data()
    train_data = train_data / 255.0
    train_data = np.reshape(train_data, (train_data.shape[0], -1))
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)

    # Model preparation
    data_size = train_data.shape[1]
    target_size = 10
    model = make_functional_model(data_size, target_size)

    # Training
    criterion = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(5):
        train(train_dataset, model, criterion, optimizer)

    # Save in HDF5 format (a SavedModel fails to import with deeplearning4j)
    model.save('checkpoint.h5', save_format='h5')

if __name__ == '__main__':
    main()
```

That's all for the training code. Note that deeplearning4j worked correctly only when a functional model was saved in HDF5 format: I could not import models defined in the Sequential or subclassed style, nor models saved in the SavedModel format.

Inference code (Kotlin, deeplearning4j)

Next is the inference code, written in Kotlin with deeplearning4j.

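```kotlin
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport

fun main() {
    // MNIST test set (train = false), batch size 1, random seed 123
    val mnist = MnistDataSetIterator(1, false, 123)
    // A functional Keras model is imported as a ComputationGraph
    val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
    // The Sequential importer exists, but importing a Sequential model failed
    //val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")

    var correct = 0
    var total = 0
    mnist.forEach {
        val data = it.features
        // output() returns one array per model output; this model has a single output
        val output = model.output(data)[0].toFloatVector()
        val pred = output.indexOf(output.max()!!)
        // Labels are one-hot, so the true digit is the index of the 1
        val label = it.labels.toFloatVector().let { l -> l.indexOf(l.max()!!) }
        if (pred == label) correct++
        total++
    }
    println("accuracy: ${correct.toDouble() / total}")
}
```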

All you need to do is import the model and run inference with the output function, so it is very easy to use. There is also an import function for Sequential models (commented out in the code), but as mentioned earlier, importing a Sequential model failed.
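MnistDataSetIterator takes care of loading and scaling the data, but to classify your own images you would need to reproduce the Python preprocessing (divide by 255, flatten 28x28 to 784 values) and then pick the class with the highest score. A minimal pure-Kotlin sketch, assuming a 28x28 grayscale image given as rows of 0-255 ints; `preprocess` and `argmax` are hypothetical helpers, not part of deeplearning4j. The plain loop in `argmax` also avoids depending on `FloatArray.max()`, whose signature has changed between Kotlin versions (it was deprecated in favor of `maxOrNull()` for a while).

```kotlin
// Hypothetical helper: mirrors the Python preprocessing
// (train_data / 255.0, then np.reshape to 784 values in row-major order)
fun preprocess(image: Array<IntArray>): DoubleArray {
    require(image.size == 28 && image.all { it.size == 28 }) { "expected a 28x28 image" }
    return DoubleArray(28 * 28) { i -> image[i / 28][i % 28] / 255.0 }
}

// Hypothetical helper: index of the largest score, i.e. the predicted digit
fun argmax(scores: FloatArray): Int {
    require(scores.isNotEmpty()) { "scores must not be empty" }
    var best = 0
    for (i in 1 until scores.size) {
        if (scores[i] > scores[best]) best = i
    }
    return best
}

fun main() {
    // A dummy image with a single bright pixel in the top-left corner
    val image = Array(28) { IntArray(28) }
    image[0][0] = 255
    val input = preprocess(image)
    println(input.size) // 784
    println(input[0])   // 1.0
    // Dummy softmax scores standing in for model.output(...)[0].toFloatVector()
    println(argmax(floatArrayOf(0.01f, 0.02f, 0.90f, 0.07f))) // 2
}
```

To run the model on a preprocessed image, you could wrap the vector as a 1x784 array (for example with `Nd4j.create(arrayOf(input))`), pass it to `model.output(...)`, and apply `argmax` to `toFloatVector()` of the first output.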

In conclusion

This article introduced how to use a model trained with TensorFlow 2.0 from Kotlin/Java. For now I am making do with this approach, hoping that a stable TensorFlow Java API becomes available in the future.

And let's all do deep learning with Kotlin! (This is what I wanted to say most.)