In diesem Artikel werde ich Ihnen zeigen, wie Sie das trainierte Modell von Tensorflow 2.0 mit Kotlin verwenden. Der Beispielcode ist nur Kotlin, aber ich denke, er wird in Java auf ähnliche Weise funktionieren.
In der diesmal eingeführten Methode wird KerasModelImport einer Bibliothek namens deeplearning4j verwendet. Da Tensorflow über eine Java-API verfügt, muss keine kleinere Bibliothek wie deeplearning4j ~~ verwendet werden. Da der Tensorflow2.0-Build für Java jedoch noch nicht verteilt wurde, wird deeplearning4j vorläufig verwendet. Ich werde es benutzen. (* Wenn Sie es selbst erstellen, können Sie möglicherweise die Java-API verwenden, die Tensorflow 2.0 unterstützt.)
Mit anderen Worten
Ich möchte die Deep Learning-Inferenzverarbeitung mit Kotlin / Java ausführen! Aber ich möchte keinen Lerncode mit deeplearning4j schreiben! !! Ich kann die Veröffentlichung des Tensorflow 2.0 Builds für Java kaum erwarten! !! !!
Ich hoffe, Sie können sich das als Verbindungsmaß für eine solche Person vorstellen.
Sowohl Tensorflow als auch Deeplearning4j ändern sich drastisch von Version zu Version, sodass die Wahrscheinlichkeit groß ist, dass sich das Verhalten je nach Version ändert.
tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7
Als Beispiel wird der Python-Code zum Lernen von MNIST mit Tensorflow 2.0 unten gezeigt.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
#Wenn es nicht float64 ist, funktioniert es unter deeplearning4j möglicherweise nicht richtig.
tf.keras.backend.set_floatx('float64')
#Sequentielles SubClassed-Modell kann nicht mit deeplearning4j importiert werden
def make_functional_model(data_size, target_size):
inputs = Input(data_size)
fc1 = Dense(512, activation='relu')
fc2 = Dense(target_size, activation='softmax')
outputs = fc2(fc1(inputs))
return Model(inputs=inputs, outputs=outputs)
def train(dataset, model, criterion, optimizer):
for data, target in dataset:
with tf.GradientTape() as tape:
output = model(data)
loss = criterion(target, output)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
def main():
#Vorbereitung für MNIST
mnist = tf.keras.datasets.mnist
(train_data, train_target), _ = mnist.load_data()
train_data = train_data / 255.0
train_data = np.reshape(train_data, (train_data.shape[0], -1))
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)
#Modellvorbereitung
data_size = train_data.shape[1]
target_size = 10
model = make_functional_model(data_size, target_size)
#Lernen
criterion = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
for epoch in range(5):
train(train_dataset, model, criterion, optimizer)
#Das gespeicherte Modell kann nicht mit deeplearning4j importiert werden
model.save('checkpoint.h5')
if __name__ == '__main__':
main()
Das ist alles für den Lerncode. Außerdem wurde der normale Betrieb mit deeplearning4j nur beim "Speichern des Funktionsmodells im HDF5-Format" bestätigt. Es war nicht möglich, die Modellbeschreibung im sequentiellen oder unterklassierten Format oder im gespeicherten Format des gespeicherten Modells zu importieren.
Als nächstes kommt der Inferenzcode. Dies ist Kotlin mit deeplearning4j.
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport
fun main() {
val mnist = MnistDataSetIterator(1, false, 123)
val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
//val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")
mnist.forEach {
val data = it.features
val output = model.output(data)[0].toFloatVector()
val pred = output.indexOf(output.max()!!)
}
}
Importieren Sie einfach das Modell und schließen Sie mit der Ausgabefunktion. Es ist sehr einfach zu bedienen. Es gibt auch eine Importfunktion für das sequentielle Modell, aber wie oben erwähnt schlägt der Import fehl.
In diesem Artikel habe ich die Verwendung des trainierten Modells von Tensorflow 2.0 mit Kotlin / Java vorgestellt. Ich hoffe, dass die Java-API von Tensorflow stabil verfügbar sein wird, und ich werde es jetzt ertragen. ..
Und lasst uns alle Deep Learning mit Kotlin machen! (Das möchte ich am liebsten sagen)
Recommended Posts