Dans cet article, je vais vous montrer comment utiliser le modèle entraîné de tensorflow 2.0 avec Kotlin. L'exemple de code est uniquement Kotlin, mais je pense qu'il fonctionnera en Java de la même manière.
Dans la méthode introduite cette fois, KerasModelImport d'une bibliothèque appelée deeplearning4j est utilisée. Étant donné que tensorflow a une API Java, il n'est pas nécessaire d'utiliser une bibliothèque mineure telle que deeplearning4j ~~, mais comme la version tensorflow2.0 pour Java n'a pas encore été distribuée, deeplearning4j est provisoirement utilisé. Je vais l'utiliser. (* Si vous le construisez vous-même, vous pourrez peut-être utiliser l'API Java compatible avec tensorflow 2.0)
En d'autres termes
Je souhaite exécuter le traitement d'inférence Deep Learning avec Kotlin / Java! Mais je ne veux pas écrire de code d'apprentissage avec deeplearning4j! !! J'ai hâte de voir la distribution de la version tensorflow 2.0 pour Java! !! !!
J'espère que vous pouvez le considérer comme une mesure de connexion pour une telle personne.
Tensorflow et deeplearning4j changent radicalement d'une version à l'autre, il y a donc une forte possibilité que le comportement change en fonction de la version.
tensorflow(Python) : 2.1.0
deeplearning4j(Kotlin/Java) : 1.0.0-beta7
À titre d'exemple, le code Python pour l'apprentissage de MNIST avec tensorflow 2.0 est présenté ci-dessous.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras import Model
#Si ce n'est pas float64, cela peut ne pas fonctionner correctement sur deeplearning4j.
tf.keras.backend.set_floatx('float64')
#L'importation du modèle séquentiel sous-classé échoue avec deeplearning4j
def make_functional_model(data_size, target_size):
inputs = Input(data_size)
fc1 = Dense(512, activation='relu')
fc2 = Dense(target_size, activation='softmax')
outputs = fc2(fc1(inputs))
return Model(inputs=inputs, outputs=outputs)
def train(dataset, model, criterion, optimizer):
for data, target in dataset:
with tf.GradientTape() as tape:
output = model(data)
loss = criterion(target, output)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
def main():
#Préparation pour MNIST
mnist = tf.keras.datasets.mnist
(train_data, train_target), _ = mnist.load_data()
train_data = train_data / 255.0
train_data = np.reshape(train_data, (train_data.shape[0], -1))
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_target)).batch(32)
#Préparation du modèle
data_size = train_data.shape[1]
target_size = 10
model = make_functional_model(data_size, target_size)
#Apprentissage
criterion = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
for epoch in range(5):
train(train_dataset, model, criterion, optimizer)
#Le modèle enregistré ne parvient pas à importer avec deeplearning4j
model.save('checkpoint.h5')
if __name__ == '__main__':
main()
C'est tout pour le code d'apprentissage. De plus, le fonctionnement normal a été confirmé avec deeplearning4j uniquement lorsque "Enregistrer le modèle fonctionnel au format hdf5". Il n'a pas été possible d'importer dans la description du modèle au format séquentiel ou sous-classé ou au format enregistré du modèle enregistré.
Vient ensuite le code d'inférence. C'est Kotlin qui utilise deeplearning4j.
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport
fun main() {
val mnist = MnistDataSetIterator(1, false, 123)
val model = KerasModelImport.importKerasModelAndWeights("checkpoint.h5")
//val model = KerasModelImport.importKerasSequentialModelAndWeights("checkpoint.h5")
mnist.forEach {
val data = it.features
val output = model.output(data)[0].toFloatVector()
val pred = output.indexOf(output.max()!!)
}
}
Il suffit d'importer le modèle et de déduire avec la fonction de sortie. C'est très simple à utiliser. Il existe également une fonction d'importation pour le modèle séquentiel, mais comme mentionné ci-dessus, l'importation échoue.
Dans cet article, j'ai présenté comment utiliser le modèle entraîné de tensorflow 2.0 avec Kotlin / Java. J'espère que l'API Java de tensorflow sera disponible de manière stable, et je la supporterai maintenant. ..
Et faisons tous du Deep Learning avec Kotlin! (C'est ce que je veux dire le plus)
Recommended Posts