Verwendung des mit Tensorflow 2.0 trainierten Modells mit Kotlin / Java

Einführung

In diesem Artikel werde ich Ihnen zeigen, wie Sie das trainierte Modell von Tensorflow 2.0 mit Kotlin verwenden. Der Beispielcode ist nur Kotlin, aber ich denke, er wird in Java auf ähnliche Weise funktionieren.

In der diesmal eingeführten Methode wird KerasModelImport einer Bibliothek namens deeplearning4j verwendet. Da Tensorflow über eine Java-API verfügt, muss keine kleinere Bibliothek wie deeplearning4j ~~ verwendet werden. Da der Tensorflow2.0-Build für Java jedoch noch nicht verteilt wurde, wird deeplearning4j vorläufig verwendet. Ich werde es benutzen. (* Wenn Sie es selbst erstellen, können Sie möglicherweise die Java-API verwenden, die Tensorflow 2.0 unterstützt.)

Mit anderen Worten

Ich möchte die Deep Learning-Inferenzverarbeitung mit Kotlin / Java ausführen! Aber ich möchte keinen Lerncode mit deeplearning4j schreiben! !! Ich kann die Veröffentlichung des Tensorflow 2.0 Builds für Java kaum erwarten! !! !!

Ich hoffe, Sie können sich das als Verbindungsmaß für eine solche Person vorstellen.

Bibliotheksversion

Sowohl Tensorflow als auch Deeplearning4j ändern sich drastisch von Version zu Version, sodass die Wahrscheinlichkeit groß ist, dass sich das Verhalten je nach Version ändert.

tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7

Lerncode (Python, Tensorflow)

Als Beispiel wird der Python-Code zum Lernen von MNIST mit Tensorflow 2.0 unten gezeigt.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model

#Wenn es nicht float64 ist, funktioniert es unter deeplearning4j möglicherweise nicht richtig.
tf.keras.backend.set_floatx('float64')

#Sequentielles SubClassed-Modell kann nicht mit deeplearning4j importiert werden
def make_functional_model(data_size, target_size):
    inputs = Input(data_size)
    fc1 = Dense(512, activation='relu')
    fc2 = Dense(target_size, activation='softmax')
    outputs = fc2(fc1(inputs))
    return Model(inputs=inputs, outputs=outputs)

def train(dataset, model, criterion, optimizer):
    for data, target in dataset:
        with tf.GradientTape() as tape:
            output = model(data)
            loss = criterion(target, output)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

def main():

    #Vorbereitung für MNIST
    mnist = tf.keras.datasets.mnist
    (train_data, train_target), _ = mnist.load_data()
    train_data = train_data / 255.0
    train_data = np.reshape(train_data, (train_data.shape[0], -1))
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)

    #Modellvorbereitung
    data_size = train_data.shape[1]
    target_size = 10
    model = make_functional_model(data_size, target_size)

    #Lernen
    criterion = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(5):
        train(train_dataset, model, criterion, optimizer)
        
    #Das gespeicherte Modell kann nicht mit deeplearning4j importiert werden
    model.save('checkpoint.h5')

if __name__ == '__main__':
    main()

Das ist alles für den Lerncode. Außerdem wurde der normale Betrieb mit deeplearning4j nur beim "Speichern des Funktionsmodells im HDF5-Format" bestätigt. Es war nicht möglich, die Modellbeschreibung im sequentiellen oder unterklassierten Format oder im gespeicherten Format des gespeicherten Modells zu importieren.

Inferenzcode (Kotlin, Deep Learning4j)

Als nächstes kommt der Inferenzcode. Dies ist Kotlin mit deeplearning4j.

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport

fun main() {

    val mnist = MnistDataSetIterator(1, false, 123)
    val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
    //val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")

    mnist.forEach {
        val data = it.features
        val output = model.output(data)[0].toFloatVector()
        val pred = output.indexOf(output.max()!!)
    }
}

Importieren Sie einfach das Modell und schließen Sie mit der Ausgabefunktion. Es ist sehr einfach zu bedienen. Es gibt auch eine Importfunktion für das sequentielle Modell, aber wie oben erwähnt schlägt der Import fehl.

abschließend

In diesem Artikel habe ich die Verwendung des trainierten Modells von Tensorflow 2.0 mit Kotlin / Java vorgestellt. Ich hoffe, dass die Java-API von Tensorflow stabil verfügbar sein wird, und ich werde es jetzt ertragen. ..

Und lasst uns alle Deep Learning mit Kotlin machen! (Das möchte ich am liebsten sagen)

Recommended Posts

Verwendung des mit Tensorflow 2.0 trainierten Modells mit Kotlin / Java
[Java] [Maven3] Zusammenfassung der Verwendung von Maven3
Verwendung des Java-Frameworks mit AWS Lambda! ??
Verwendung der Java-API mit Lambda-Ausdrücken
[Java] Verwendung von Map
Verwendung von Java Optional
Verwendung der Java-Klasse
[Java] Verwendung von removeAll ()
Verwendung von Java Map
Verwendung von Java-Variablen
Zusammenfassung der Java-Kommunikations-API (1) Verwendung von Socket
Zusammenfassung der Java-Kommunikations-API (3) Verwendung von SocketChannel
Zusammenfassung der Java-Kommunikations-API (2) Verwendung von HttpUrlConnection
Verwendung von HttpClient (Post) von Java
[Java] Verwendung der Join-Methode
Verwendung von setDefaultCloseOperation () von JFrame
[Verarbeitung × Java] Verwendung von Variablen
Wie man mssql-tools mit alpine benutzt
[JavaFX] [Java8] Verwendung von GridPane
Verwendung von Klassenmethoden [Java]
[Java] Verwendung von List [ArrayList]
Wie verwende ich Klassen in Java?
[Verarbeitung × Java] Verwendung von Arrays
Verwendung von Java-Lambda-Ausdrücken
[Java] Verwendung der Math-Klasse
Verwendung des Java-Aufzählungstyps
Zusammenfassung der Verwendung des im IE festgelegten Proxy-Sets bei der Verbindung mit Java
Mehrsprachige Unterstützung für Java Verwendung des Gebietsschemas
[Java] Verwendung der File-Klasse
So kompilieren Sie Java mit VsCode & Ant
[Java] Fassen Sie zusammen, wie Sie mit der Methode equals vergleichen können
Verwendung von BootStrap mit Play Framework
[java] Zusammenfassung des Umgangs mit char
Verwendung der Submit-Methode (Java Silver)
[Leicht verständliche Erklärung! ] Verwendung der Java-Instanz
[Java] Verwendung der toString () -Methode
Studieren der Verwendung des Konstruktors (Java)
[Verarbeitung × Java] Verwendung der Schleife
Verwendung und Definition von Java-Klassen, Importieren
[Leicht verständliche Erklärung! ] Verwendung des Java-Polymorphismus
[Verarbeitung × Java] Verwendung der Klasse
Verwendung der Java Scanner-Klasse (Hinweis)
[Verarbeitung × Java] Verwendung der Funktion
[Leicht verständliche Erklärung! ] Verwendung von ArrayList [Java]
[Java] Verwendung der Calendar-Klasse
[Java] Erfahren Sie, wie Sie Optional richtig verwenden
[Leicht verständliche Erklärung! ] Verwendung von Java-Überladung
try-catch-finally Ausnahmebehandlung Verwendung von Java
[Leicht verständliche Erklärung! ] Verwendung der Java-Kapselung
[Java] So verwenden Sie Teilzeichenfolgen, um einen Teil einer Zeichenfolge auszuschneiden
[Java] [ibatis] So erhalten Sie 1-zu-N-Beziehungsdatensätze mit List <Map <>>
[Java] Beachten Sie, wie Sie RecyclerView verwenden und die animierte Swipe-Verarbeitung implementieren.
Verwenden Sie jenv, um mehrere Java-Versionen zu aktivieren