Comment utiliser le modèle entraîné Tensorflow 2.0 avec Kotlin / Java

introduction

Dans cet article, je vais vous montrer comment utiliser le modèle entraîné de tensorflow 2.0 avec Kotlin. L'exemple de code est uniquement Kotlin, mais je pense qu'il fonctionnera en Java de la même manière.

Dans la méthode introduite cette fois, KerasModelImport d'une bibliothèque appelée deeplearning4j est utilisée. Étant donné que tensorflow a une API Java, il n'est pas nécessaire d'utiliser une bibliothèque mineure telle que deeplearning4j ~~, mais comme la version tensorflow2.0 pour Java n'a pas encore été distribuée, deeplearning4j est provisoirement utilisé. Je vais l'utiliser. (* Si vous le construisez vous-même, vous pourrez peut-être utiliser l'API Java compatible avec tensorflow 2.0)

En d'autres termes

Je souhaite exécuter le traitement d'inférence Deep Learning avec Kotlin / Java! Mais je ne veux pas écrire de code d'apprentissage avec deeplearning4j! !! J'ai hâte de voir la distribution de la version tensorflow 2.0 pour Java! !! !!

J'espère que vous pouvez le considérer comme une mesure de connexion pour une telle personne.

Version de la bibliothèque

Tensorflow et deeplearning4j changent radicalement d'une version à l'autre, il y a donc une forte possibilité que le comportement change en fonction de la version.

tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7

Code d'apprentissage (Python, tensorflow)

À titre d'exemple, le code Python pour l'apprentissage de MNIST avec tensorflow 2.0 est présenté ci-dessous.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model

#Si ce n'est pas float64, cela peut ne pas fonctionner correctement sur deeplearning4j.
tf.keras.backend.set_floatx('float64')

#L'importation du modèle séquentiel sous-classé échoue avec deeplearning4j
def make_functional_model(data_size, target_size):
    inputs = Input(data_size)
    fc1 = Dense(512, activation='relu')
    fc2 = Dense(target_size, activation='softmax')
    outputs = fc2(fc1(inputs))
    return Model(inputs=inputs, outputs=outputs)

def train(dataset, model, criterion, optimizer):
    for data, target in dataset:
        with tf.GradientTape() as tape:
            output = model(data)
            loss = criterion(target, output)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

def main():

    #Préparation pour MNIST
    mnist = tf.keras.datasets.mnist
    (train_data, train_target), _ = mnist.load_data()
    train_data = train_data / 255.0
    train_data = np.reshape(train_data, (train_data.shape[0], -1))
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)

    #Préparation du modèle
    data_size = train_data.shape[1]
    target_size = 10
    model = make_functional_model(data_size, target_size)

    #Apprentissage
    criterion = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(5):
        train(train_dataset, model, criterion, optimizer)
        
    #Le modèle enregistré ne parvient pas à importer avec deeplearning4j
    model.save('checkpoint.h5')

if __name__ == '__main__':
    main()

C'est tout pour le code d'apprentissage. De plus, le fonctionnement normal a été confirmé avec deeplearning4j uniquement lorsque "Enregistrer le modèle fonctionnel au format hdf5". Il n'a pas été possible d'importer dans la description du modèle au format séquentiel ou sous-classé ou au format enregistré du modèle enregistré.

Code d'inférence (Kotlin, deep learning4j)

Vient ensuite le code d'inférence. C'est Kotlin qui utilise deeplearning4j.

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport

fun main() {

    val mnist = MnistDataSetIterator(1, false, 123)
    val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
    //val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")

    mnist.forEach {
        val data = it.features
        val output = model.output(data)[0].toFloatVector()
        val pred = output.indexOf(output.max()!!)
    }
}

Il suffit d'importer le modèle et de déduire avec la fonction de sortie. C'est très simple à utiliser. Il existe également une fonction d'importation pour le modèle séquentiel, mais comme mentionné ci-dessus, l'importation échoue.

en conclusion

Dans cet article, j'ai présenté comment utiliser le modèle entraîné de tensorflow 2.0 avec Kotlin / Java. J'espère que l'API Java de tensorflow sera disponible de manière stable, et je la supporterai maintenant. ..

Et faisons tous du Deep Learning avec Kotlin! (C'est ce que je veux dire le plus)

Recommended Posts

Comment utiliser le modèle entraîné Tensorflow 2.0 avec Kotlin / Java
[Java] [Maven3] Résumé de l'utilisation de Maven3
Comment utiliser le framework Java avec AWS Lambda! ??
Comment utiliser l'API Java avec des expressions lambda
[Java] Comment utiliser Map
Comment utiliser java Facultatif
Comment utiliser la classe Java
[Java] Comment utiliser removeAll ()
Comment utiliser Java Map
Comment utiliser les variables Java
Résumé de l'API de communication Java (1) Comment utiliser Socket
Résumé de l'API de communication Java (3) Comment utiliser SocketChannel
Résumé de l'API de communication Java (2) Comment utiliser HttpUrlConnection
Comment utiliser HttpClient de Java (Post)
[Java] Comment utiliser la méthode de jointure
Comment utiliser setDefaultCloseOperation () de JFrame
[Traitement × Java] Comment utiliser les variables
Comment utiliser mssql-tools avec Alpine
[JavaFX] [Java8] Comment utiliser GridPane
Comment utiliser les méthodes de classe [Java]
[Java] Comment utiliser List [ArrayList]
Comment utiliser les classes en Java?
[Traitement × Java] Comment utiliser les tableaux
Comment utiliser les expressions Java lambda
[Java] Comment utiliser la classe Math
Comment utiliser le type enum Java
Résumé de l'utilisation du jeu de proxy dans IE lors de la connexion avec Java
Prise en charge multilingue de Java Comment utiliser les paramètres régionaux
[Java] Comment utiliser la classe File
Comment compiler Java avec VsCode & Ant
[Java] Résumez comment comparer avec la méthode equals
Comment utiliser BootStrap avec Play Framework
[java] Résumé de la gestion des caractères
Comment utiliser la méthode de soumission (Java Silver)
[Explication facile à comprendre! ] Comment utiliser l'instance Java
[Java] Comment utiliser la méthode toString ()
Etudier comment utiliser le constructeur (java)
[Traitement × Java] Comment utiliser la boucle
Comment utiliser et définir les classes Java, importer
[Explication facile à comprendre! ] Comment utiliser le polymorphisme Java
[Traitement × Java] Comment utiliser la classe
Comment utiliser la classe Java Scanner (Remarque)
[Traitement × Java] Comment utiliser la fonction
[Explication facile à comprendre! ] Comment utiliser ArrayList [Java]
[Java] Comment utiliser la classe Calendar
[Java] Découvrez comment utiliser correctement Optional
[Explication facile à comprendre! ] Comment utiliser la surcharge Java
gestion des exceptions try-catch-finally Comment utiliser java
[Explication facile à comprendre! ] Comment utiliser l'encapsulation Java
[Java] Comment utiliser une sous-chaîne pour découper une partie d'une chaîne de caractères
[Java] [ibatis] Comment obtenir des enregistrements de relation 1 à N avec List <Map <>>
[Java] Notez comment utiliser RecyclerView et comment implémenter le traitement par balayage animé.
Utilisez jenv pour activer plusieurs versions de Java